# -*- coding: utf-8 -*-
import glob
import json
import math
import os
import warnings
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from scipy.special import logsumexp
from tensorflow.keras.preprocessing import image
# Test directory
TEST_DIR = "content/test"
# OOD Detection Configuration
DEFAULT_TEMPERATURE = 1
DEFAULT_CONFIDENCE_THRESHOLD = 0.95
ENABLE_OOD_DETECTION = True
DEFAULT_PERTURBATION_CONFIG = {
    'gaussian_noise': {
        'enabled': True,
        'std': 0.005,  # reduced
        'num_samples': 10  # increased
    },
    'uniform_noise': {
        'enabled': False,
        'scale': 0.005,
        'num_samples': 10
    },
    'dropout': {
        'enabled': True,
        'rate': 0.03,  # reduced
        'num_samples': 10
    },
    'mixup': {
        'enabled': False,
        'alpha': 0.1,
        'num_samples': 2
    }
}
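# Note: generate_perturbations() always includes the unperturbed input, so with the
# defaults above each image costs 1 + 10 (gaussian) + 10 (dropout) = 21 forward passes.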
DEFAULT_OOD_CONFIG = {
    'enabled': True,
    'temperature': 1.0,
    'confidence_threshold': 0.5,
    'entropy_threshold': 1.0,
    'variance_threshold': 0.1,
    'use_uncertainty': True,
    'use_energy': True,
    'energy_threshold': 10.0,
    'aggregation_method': 'mean'
}
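# The criteria above are combined with OR in detect_ood_with_perturbation(): a sample is
# flagged OOD if confidence is low, entropy or variance is high, or energy is high.
# For reference, predictive entropy is at most ln(num_classes), so 'entropy_threshold'
# should be chosen relative to that ceiling (e.g. ln(10) ≈ 2.30 for 10 classes).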
def discover_trained_models():
    """Discover all trained models in the content directory"""
    config_files = glob.glob("./content/model_config_*.json")
    models = []
    for config_file in config_files:
        try:
            with open(config_file, 'r') as f:
                config = json.load(f)
            # Check if model files exist
            model_exists = (
                os.path.exists(config['model_export_dir'] + ".keras") or
                os.path.exists(config['model_export_dir'] + "_keras") or
                os.path.exists(config['model_export_dir'])
            )
            if model_exists:
                models.append({
                    'config_file': config_file,
                    'config': config,
                    'display_name': f"{config['model_name'].upper()} (img_size: {config['img_size']})"
                })
        except Exception as e:
            print(f"Error reading config file {config_file}: {e}")
            continue
    return models
def select_model_for_testing():
    """Interactive model selection for testing"""
    models = discover_trained_models()
    if not models:
        print("No trained models found!")
        print("Please run the training script first.")
        return None
    print("Available trained models:")
    print("=" * 50)
    for i, model_info in enumerate(models, 1):
        print(f"{i:2d}. {model_info['display_name']}")
        print(f"    Directory suffix: {model_info['config']['dir_suffix']}")
    print("=" * 50)
    while True:
        try:
            choice = input(f"Select model to test (1-{len(models)}): ").strip()
            if choice.isdigit():
                choice_idx = int(choice) - 1
                if 0 <= choice_idx < len(models):
                    selected_model = models[choice_idx]
                    break
                else:
                    print(f"Please enter a number between 1 and {len(models)}")
                    continue
            print("Please enter a valid number.")
        except KeyboardInterrupt:
            print("\nOperation cancelled.")
            return None
        except Exception as e:
            print(f"Error: {e}")
            continue
    config = selected_model['config']
    print(f"\nSelected: {config['model_name'].upper()}")
    print(f"Image Size: {config['img_size']}")
    print(f"Directory Suffix: {config['dir_suffix']}")
    return config
def load_model_and_classes(config):
    """Load the trained model and class names based on configuration"""
    model = None
    model_type = None
    model_export_dir = config['model_export_dir']
    class_names_file = config['class_names_file']
    # Try different model formats in order of preference
    model_paths_to_try = [
        (model_export_dir + ".keras", "keras", "Keras .keras file"),
        (model_export_dir + "_keras", "keras", "Keras directory format"),
        (model_export_dir, "savedmodel", "SavedModel format")
    ]
    for model_path, m_type, description in model_paths_to_try:
        try:
            print(f"Trying to load {description} from {model_path}...")
            if m_type == "keras":
                model = tf.keras.models.load_model(model_path)
            else:  # savedmodel
                model = tf.saved_model.load(model_path)
            model_type = m_type
            print(f"{description} loaded successfully!")
            break
        except Exception as e:
            print(f"Failed to load {description}: {e}")
            continue
    if model is None:
        print("Failed to load any model format!")
        return None, None, None, None
    # Load class names
    try:
        with open(class_names_file, 'r') as f:
            class_names = json.load(f)
        print(f"Class names loaded: {class_names}")
        return model, class_names, model_type, config
    except Exception as e:
        print(f"Error loading class names: {e}")
        return model, None, model_type, config
def preprocess_image(image_path, img_size, suffix='_resized', resize_method='bilinear', save_resized=False):
    """Load an image, resize it to (img_size, img_size), and return (batch_tensor, PIL image)."""
    try:
        target_img = image.load_img(image_path)  # load at native resolution; resized below
        img_array = image.img_to_array(target_img)
        img_array = tf.convert_to_tensor(img_array, dtype=tf.float32)
        img_batch = tf.expand_dims(img_array, axis=0)
        # Resize the image
        resized_img = tf.image.resize(img_batch, [img_size, img_size], method=resize_method, antialias=True)
        # Convert back to a NumPy array, then to a PIL image
        resized_img_array = resized_img.numpy()[0]
        resized_img_pil = image.array_to_img(resized_img_array)
        if save_resized:
            # Create new filename with suffix; make sure the output directory exists
            dir_name = "content/resized"
            os.makedirs(dir_name, exist_ok=True)
            base_name = os.path.basename(image_path)
            name, ext = os.path.splitext(base_name)
            new_name = f"{name}{suffix}.jpg"  # save as JPEG
            save_path = os.path.join(dir_name, new_name)
            resized_img_pil.save(save_path, format="JPEG")
            print(f"Saved resized image as: {save_path}")
        return resized_img, resized_img_pil
    except Exception as e:
        print(f"Error preprocessing image {image_path}: {e}")
        return None, None
def apply_temperature_scaling(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Apply temperature scaling to logits
    Args:
        logits: Raw model outputs (logits)
        temperature: Temperature parameter (T > 1 decreases confidence, T < 1 increases)
    Returns:
        Temperature-scaled logits
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive")
    return logits / temperature
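# Worked example: for logits [2.0, 0.5], softmax gives roughly [0.82, 0.18]; dividing by
# T = 2 first gives roughly [0.68, 0.32], i.e. the same argmax with softer confidence.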
def add_gaussian_noise(inputs: np.ndarray, noise_std: float = 0.01) -> np.ndarray:
    """Add Gaussian noise to inputs for perturbation
    Args:
        inputs: Input tensor/array
        noise_std: Standard deviation of Gaussian noise
    Returns:
        Perturbed inputs
    """
    noise = np.random.normal(0, noise_std, inputs.shape)
    return inputs + noise
def add_uniform_noise(inputs: np.ndarray, noise_scale: float = 0.01) -> np.ndarray:
    """Add uniform noise to inputs for perturbation
    Args:
        inputs: Input tensor/array
        noise_scale: Scale of uniform noise [-noise_scale, noise_scale]
    Returns:
        Perturbed inputs
    """
    noise = np.random.uniform(-noise_scale, noise_scale, inputs.shape)
    return inputs + noise
def apply_dropout_perturbation(inputs: np.ndarray, dropout_rate: float = 0.1) -> np.ndarray:
    """Apply dropout-style perturbation to inputs
    Args:
        inputs: Input tensor/array
        dropout_rate: Probability of setting elements to zero
    Returns:
        Perturbed inputs with dropout
    """
    mask = np.random.random(inputs.shape) > dropout_rate
    return inputs * mask / (1.0 - dropout_rate)  # inverted dropout: scale to maintain expected value
def apply_mixup_perturbation(inputs: np.ndarray, alpha: float = 0.2) -> Tuple[np.ndarray, float]:
    """Apply mixup-style perturbation by mixing with a random permutation
    Args:
        inputs: Input tensor/array (batch_size, ...)
        alpha: Beta distribution parameter for mixup
    Returns:
        Tuple of (mixed_inputs, lambda_value)
    """
    if len(inputs.shape) < 2:
        return inputs, 1.0
    batch_size = inputs.shape[0]
    if batch_size < 2:
        return inputs, 1.0
    # Sample lambda from Beta distribution
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    # Create random permutation
    indices = np.random.permutation(batch_size)
    # Mix inputs
    mixed_inputs = lam * inputs + (1 - lam) * inputs[indices]
    return mixed_inputs, lam
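# Note: predictions in this script run on single-image batches (shape (1, H, W, 3)), so
# apply_mixup_perturbation() takes the batch_size < 2 early return above and is a no-op
# here; it is kept for batch use and is disabled in DEFAULT_PERTURBATION_CONFIG.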
def generate_perturbations(inputs: np.ndarray, perturbation_config: Dict[str, Any]) -> List[np.ndarray]:
    """Generate multiple perturbations of the input
    Args:
        inputs: Input tensor/array
        perturbation_config: Configuration dict with perturbation parameters
    Returns:
        List of perturbed inputs
    """
    # Work in NumPy so the noise/dropout operations below behave consistently even when
    # the caller passes a tf.Tensor (as preprocess_image does)
    inputs = np.asarray(inputs, dtype=np.float32)
    perturbations = [inputs]  # Include original
    if perturbation_config.get('gaussian_noise', {}).get('enabled', False):
        noise_std = perturbation_config['gaussian_noise'].get('std', 0.01)
        num_samples = perturbation_config['gaussian_noise'].get('num_samples', 3)
        for _ in range(num_samples):
            perturbations.append(add_gaussian_noise(inputs, noise_std))
    if perturbation_config.get('uniform_noise', {}).get('enabled', False):
        noise_scale = perturbation_config['uniform_noise'].get('scale', 0.01)
        num_samples = perturbation_config['uniform_noise'].get('num_samples', 3)
        for _ in range(num_samples):
            perturbations.append(add_uniform_noise(inputs, noise_scale))
    if perturbation_config.get('dropout', {}).get('enabled', False):
        dropout_rate = perturbation_config['dropout'].get('rate', 0.1)
        num_samples = perturbation_config['dropout'].get('num_samples', 3)
        for _ in range(num_samples):
            perturbations.append(apply_dropout_perturbation(inputs, dropout_rate))
    if perturbation_config.get('mixup', {}).get('enabled', False):
        alpha = perturbation_config['mixup'].get('alpha', 0.2)
        num_samples = perturbation_config['mixup'].get('num_samples', 2)
        for _ in range(num_samples):
            mixed_inputs, _ = apply_mixup_perturbation(inputs, alpha)
            perturbations.append(mixed_inputs)
    return perturbations
def is_logits(predictions: np.ndarray, threshold: float = 10.0) -> bool:
    """
    Heuristic to determine whether predictions are logits or probabilities
    Args:
        predictions: Model outputs
        threshold: If max value > threshold, likely logits
    Returns:
        bool: True if likely logits, False if likely probabilities
    """
    max_val = np.max(predictions)
    min_val = np.min(predictions)
    # Check multiple conditions: probabilities are non-negative, bounded, and sum to 1
    has_negative = min_val < 0
    max_too_large = max_val > threshold
    sum_not_one = np.abs(np.sum(predictions, axis=-1) - 1.0) > 0.1
    return bool(has_negative or max_too_large or np.any(sum_not_one))
def softmax_with_temperature(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """
    Compute softmax with temperature scaling
    Args:
        logits: Input logits
        temperature: Temperature parameter
    Returns:
        Softmax probabilities
    """
    # Apply temperature scaling
    scaled_logits = logits / temperature
    # Compute softmax with numerical stability
    if isinstance(scaled_logits, tf.Tensor):
        return tf.nn.softmax(scaled_logits, axis=-1)
    else:
        # Numerical stability: subtract max
        max_logits = np.max(scaled_logits, axis=-1, keepdims=True)
        exp_logits = np.exp(scaled_logits - max_logits)
        return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
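# Subtracting the row-wise max before exponentiating leaves softmax unchanged
# (softmax(x) == softmax(x - c) for any constant c) but keeps np.exp from overflowing
# on large logits; np.exp(710) already overflows float64.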
def aggregate_predictions(predictions_list: List[np.ndarray],
                          aggregation_method: str = 'mean',
                          weights: Optional[List[float]] = None) -> np.ndarray:
    """
    Aggregate predictions from multiple perturbations
    Args:
        predictions_list: List of prediction arrays
        aggregation_method: 'mean', 'weighted_mean', 'median', or 'vote'
        weights: Weights for weighted aggregation
    Returns:
        Aggregated predictions
    """
    if not predictions_list:
        raise ValueError("Empty predictions list")
    if len(predictions_list) == 1:
        return predictions_list[0]
    predictions_array = np.array(predictions_list)
    if aggregation_method == 'mean':
        return np.mean(predictions_array, axis=0)
    elif aggregation_method == 'weighted_mean':
        if weights is None:
            weights = [1.0] * len(predictions_list)
        weights = np.array(weights) / np.sum(weights)
        return np.average(predictions_array, axis=0, weights=weights)
    elif aggregation_method == 'median':
        return np.median(predictions_array, axis=0)
    elif aggregation_method == 'vote':
        # Majority voting based on argmax over each perturbation's prediction
        votes = np.argmax(predictions_array, axis=-1)  # shape: (num_perturbations, batch)
        result = np.zeros_like(predictions_array[0])
        if result.ndim > 1:
            for i in range(result.shape[0]):
                unique, counts = np.unique(votes[:, i], return_counts=True)
                result[i, unique[np.argmax(counts)]] = 1.0
        else:
            unique, counts = np.unique(votes, return_counts=True)
            result[unique[np.argmax(counts)]] = 1.0
        return result
    else:
        raise ValueError(f"Unknown aggregation method: {aggregation_method}")
def compute_prediction_uncertainty(predictions_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute prediction uncertainty metrics from multiple perturbations
    Args:
        predictions_list: List of prediction arrays
    Returns:
        Tuple of (predictive_entropy, prediction_variance)
    """
    if len(predictions_list) <= 1:
        return np.zeros(predictions_list[0].shape[:-1]), np.zeros(predictions_list[0].shape[:-1])
    predictions_array = np.array(predictions_list)
    # Predictive entropy of the mean prediction
    mean_predictions = np.mean(predictions_array, axis=0)
    predictive_entropy = -np.sum(mean_predictions * np.log(mean_predictions + 1e-10), axis=-1)
    # Variance of predictions across perturbations, averaged over classes
    prediction_variance = np.var(predictions_array, axis=0)
    prediction_variance = np.mean(prediction_variance, axis=-1)
    return predictive_entropy, prediction_variance
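# Strictly speaking, the entropy of the mean prediction measures *total* uncertainty,
# while the variance across perturbations measures how sensitive the model is to input
# noise (an epistemic-style signal); both tend to rise for inputs far from the training data.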
# Energy-based OOD detection
def compute_energy_score(logits: np.ndarray) -> np.ndarray:
    """
    Compute energy score for OOD detection: E(x) = -T * log(sum(exp(f(x)/T))) with T = 1
    Lower energy for in-distribution samples, higher energy for OOD samples
    """
    if not is_logits(logits):
        warnings.warn("Energy-based detection requires logits, not probabilities!")
        # Rough approximation of logits from probabilities
        eps = 1e-10
        logits = np.log(np.clip(logits, eps, 1.0))
    if isinstance(logits, tf.Tensor):
        return -tf.reduce_logsumexp(logits, axis=-1)
    # scipy's logsumexp is imported at module level and is numerically stable
    # (it subtracts the max internally), so no manual fallback is needed
    return -logsumexp(logits, axis=-1)
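# Sign convention: with T = 1, E(x) = -logsumexp(logits). A confident in-distribution
# output such as [10, 0, 0] gives E ≈ -10.0, while a flat output such as [0, 0, 0]
# gives E ≈ -ln(3) ≈ -1.10; OOD is flagged when E exceeds 'energy_threshold'.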
def detect_ood_with_perturbation(prediction_list_logits: List[np.ndarray],
                                 predictions_list: List[np.ndarray],
                                 ood_config: Dict[str, Any]) -> Tuple[bool, str, np.ndarray, Dict[str, float]]:
    """
    Enhanced OOD detection using prediction uncertainty from perturbations
    Args:
        prediction_list_logits: List of raw model outputs (logits) from perturbations
        predictions_list: List of calibrated probabilities from perturbations
        ood_config: OOD detection configuration
    Returns:
        Tuple of (is_ood, reason, calibrated_probs, uncertainty_metrics)
    """
    if not ood_config.get('enabled', False):
        return False, "OOD detection disabled", predictions_list[0], {}
    # Aggregate predictions
    aggregation_method = ood_config.get('aggregation_method', 'mean')
    aggregated_probs = aggregate_predictions(predictions_list, aggregation_method)
    # Compute uncertainty metrics
    predictive_entropy, prediction_variance = compute_prediction_uncertainty(predictions_list)
    # Traditional confidence-based detection
    confidence_threshold = ood_config.get('confidence_threshold', 0.5)
    max_confidence = np.max(aggregated_probs, axis=-1)
    if max_confidence.ndim == 0:
        max_confidence = float(max_confidence)
    else:
        max_confidence = float(max_confidence[0])
    # Uncertainty-based detection
    entropy_threshold = ood_config.get('entropy_threshold', 1.0)
    variance_threshold = ood_config.get('variance_threshold', 0.1)
    if predictive_entropy.ndim == 0:
        entropy_value = float(predictive_entropy)
        variance_value = float(prediction_variance)
    else:
        entropy_value = float(predictive_entropy[0])
        variance_value = float(prediction_variance[0])
    # Energy-based detection; the same 'use_energy' default is used for computation and
    # for combining criteria below, so a missing key cannot leave energy_value unset
    use_uncertainty = ood_config.get('use_uncertainty', True)
    use_energy = ood_config.get('use_energy', False)
    energy_threshold = ood_config.get('energy_threshold', 10.0)
    energy_value = None
    energy_ood = False
    if use_energy:
        energy_scores = compute_energy_score(prediction_list_logits[0])
        energy_value = float(energy_scores[0]) if energy_scores.ndim > 0 else float(energy_scores)
        energy_ood = energy_value > energy_threshold
    # Determine OOD status
    confidence_ood = max_confidence < confidence_threshold
    entropy_ood = entropy_value > entropy_threshold
    variance_ood = variance_value > variance_threshold
    # Combine criteria
    if use_uncertainty and use_energy:
        is_ood = confidence_ood or entropy_ood or variance_ood or energy_ood
        reasons = []
        if confidence_ood:
            reasons.append(f"Low confidence: {max_confidence:.4f} < {confidence_threshold}")
        if entropy_ood:
            reasons.append(f"High entropy: {entropy_value:.4f} > {entropy_threshold}")
        if variance_ood:
            reasons.append(f"High variance: {variance_value:.4f} > {variance_threshold}")
        if energy_ood:
            reasons.append(f"High energy: {energy_value:.4f} > {energy_threshold}")
        if reasons:
            reason = "; ".join(reasons)
        else:
            reason = (f"In-distribution: conf={max_confidence:.4f}, entropy={entropy_value:.4f}, "
                      f"var={variance_value:.4f}, energy={energy_value:.4f}")
    elif use_energy:
        is_ood = energy_ood or confidence_ood
        if is_ood:
            if energy_ood:
                reason = f"High energy: {energy_value:.4f} > {energy_threshold}"
            else:
                reason = f"Low confidence: {max_confidence:.4f} < {confidence_threshold}"
        else:
            reason = f"Low energy: {energy_value:.4f} <= {energy_threshold}"
    elif use_uncertainty:
        is_ood = confidence_ood or entropy_ood or variance_ood
        reasons = []
        if confidence_ood:
            reasons.append(f"Low confidence: {max_confidence:.4f} < {confidence_threshold}")
        if entropy_ood:
            reasons.append(f"High entropy: {entropy_value:.4f} > {entropy_threshold}")
        if variance_ood:
            reasons.append(f"High variance: {variance_value:.4f} > {variance_threshold}")
        if reasons:
            reason = "; ".join(reasons)
        else:
            reason = (f"In-distribution: conf={max_confidence:.4f}, entropy={entropy_value:.4f}, "
                      f"var={variance_value:.4f}")
    else:
        is_ood = confidence_ood
        if is_ood:
            reason = f"Low confidence: {max_confidence:.4f} < {confidence_threshold}"
        else:
            reason = f"High confidence: {max_confidence:.4f} >= {confidence_threshold}"
    uncertainty_metrics = {
        'predictive_entropy': entropy_value,
        'prediction_variance': variance_value,
        'max_confidence': max_confidence,
        'num_perturbations': len(predictions_list)
    }
    if energy_value is not None:
        uncertainty_metrics['energy_score'] = energy_value
    return is_ood, reason, aggregated_probs, uncertainty_metrics
def predict_single_image_with_perturbation(model, img_batch: np.ndarray,
                                           class_names: List[str],
                                           model_type: str,
                                           ood_config: Optional[Dict[str, Any]] = DEFAULT_OOD_CONFIG,
                                           perturbation_config: Optional[Dict[str, Any]] = None) -> Tuple[str, float, bool, str, Dict[str, float]]:
    """
    Make prediction on a single preprocessed image with perturbation-based uncertainty estimation
    Args:
        model: Trained model
        img_batch: Preprocessed image batch
        class_names: List of class names
        model_type: "keras" or "savedmodel"
        ood_config: OOD detection configuration
        perturbation_config: Input perturbation configuration
    Returns:
        Tuple of (predicted_class, confidence, ood_detected, ood_reason, uncertainty_metrics)
    """
    try:
        # Default configurations
        if ood_config is None:
            ood_config = {'enabled': False}
        if perturbation_config is None:
            perturbation_config = {'gaussian_noise': {'enabled': False}}
        # Generate perturbations
        perturbations = generate_perturbations(img_batch, perturbation_config)
        # Get predictions for all perturbations
        all_predictions = []
        all_predictions_logits = []
        for perturbed_input in perturbations:
            # Convert to tensor if needed
            if model_type == "keras":
                predictions = model(perturbed_input, training=False)
            else:  # SavedModel
                if hasattr(model, 'signatures'):
                    if 'serving_default' in model.signatures:
                        serving_fn = model.signatures['serving_default']
                    else:
                        signature_key = list(model.signatures.keys())[0]
                        serving_fn = model.signatures[signature_key]
                    input_tensor = tf.convert_to_tensor(perturbed_input, dtype=tf.float32)
                    input_signature = serving_fn.structured_input_signature[1]
                    input_key = list(input_signature.keys())[0]
                    prediction_dict = serving_fn(**{input_key: input_tensor})
                    if isinstance(prediction_dict, dict):
                        predictions = list(prediction_dict.values())[0]
                    else:
                        predictions = prediction_dict
                else:
                    predictions = model(tf.convert_to_tensor(perturbed_input, dtype=tf.float32))
            # Convert to numpy and ensure proper shape
            if hasattr(predictions, 'numpy'):
                predictions_np = predictions.numpy()
            else:
                predictions_np = np.array(predictions)
            if predictions_np.ndim == 1:
                predictions_np = predictions_np.reshape(1, -1)
            all_predictions_logits.append(predictions_np)
            # Apply temperature scaling if configured and the output is logits
            temperature = ood_config.get('temperature', 1.0)
            if is_logits(predictions_np) and temperature != 1.0:
                predictions_np = softmax_with_temperature(predictions_np, temperature)
            elif not is_logits(predictions_np) and temperature != 1.0:
                warnings.warn("Cannot apply temperature scaling to probabilities. Need logits!")
            elif is_logits(predictions_np):
                predictions_np = softmax_with_temperature(predictions_np, 1.0)
            all_predictions.append(predictions_np)
        # Enhanced OOD detection with uncertainty
        ood_detected, ood_reason, final_probs, uncertainty_metrics = detect_ood_with_perturbation(
            all_predictions_logits, all_predictions, ood_config
        )
        # Get final prediction
        predicted_class_index = np.argmax(final_probs, axis=-1)[0]
        final_confidence = np.max(final_probs, axis=-1)[0]
        # Get class name
        if class_names and predicted_class_index < len(class_names):
            predicted_class = class_names[predicted_class_index]
        else:
            predicted_class = f"Class_{predicted_class_index}"
        return predicted_class, float(final_confidence), ood_detected, ood_reason, uncertainty_metrics
    except Exception as e:
        print(f"Error during prediction: {e}")
        print(f"Model type: {type(model)}")
        if hasattr(model, 'signatures'):
            print(f"Available signatures: {list(model.signatures.keys())}")
            for sig_name, sig in model.signatures.items():
                print(f"Signature '{sig_name}':")
                print(f"  Inputs: {sig.structured_input_signature}")
                print(f"  Outputs: {sig.structured_outputs}")
        return None, None, False, "Prediction error", {}
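# Minimal usage sketch (hypothetical path, assuming a loaded Keras model and class list):
#   img_batch, _ = preprocess_image("content/test/example.jpg", 224, resize_method="bicubic")
#   pred, conf, is_ood, why, metrics = predict_single_image_with_perturbation(
#       model, img_batch, class_names, "keras",
#       ood_config=DEFAULT_OOD_CONFIG, perturbation_config=DEFAULT_PERTURBATION_CONFIG)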
def create_results_visualization(results, config, ood_config, output_filename=None):
    """Create a visualization table of prediction results"""
    if not results:
        print("No results to visualize!")
        return None
    n_images = len(results)
    cols = min(6, n_images)  # max 6 columns
    rows = math.ceil(n_images / cols)
    fig_width = cols * 6
    fig_height = rows * 5
    # squeeze=False keeps axes 2D for any grid shape, so flattening is uniform
    fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height), squeeze=False)
    axes_flat = axes.flatten()
    for i, (image_path, img_pil, pred_class, confidence, is_ood) in enumerate(results):
        if i >= len(axes_flat):
            break
        ax = axes_flat[i]
        ax.imshow(img_pil)
        ax.axis('off')
        confidence_pct = confidence * 100
        if confidence >= 0.9:
            title_color = 'green'
            status = 'High Conf'
        elif confidence >= 0.7:
            title_color = 'orange'
            status = 'Med Conf'
        else:
            title_color = 'red'
            status = 'Low Conf'
        if is_ood:
            status = 'OOD'
        title = f"{pred_class}\n{confidence_pct:.1f}% ({status})"
        ax.set_title(title, fontsize=10, color=title_color, weight='bold')
        # Add border color based on status
        for spine in ax.spines.values():
            spine.set_edgecolor(title_color)
            spine.set_linewidth(3)
    # Hide unused subplots
    for i in range(len(results), len(axes_flat)):
        axes_flat[i].axis('off')
    # Add overall title
    model_name = config['model_name'].upper()
    img_size = config['img_size']
    total_images = len(results)
    ood_count = sum(1 for _, _, _, _, is_ood in results if is_ood)
    avg_confidence = np.mean([conf for _, _, _, conf, _ in results]) * 100
    title_text = f"{model_name} Model Results (Image Size: {img_size}x{img_size})\n"
    title_text += f"Total Images: {total_images} | OOD Detected: {ood_count} | Avg Confidence: {avg_confidence:.1f}%"
    if ood_config.get('enabled', False):
        title_text += (f"\nOOD Detection: ON (Temp: {ood_config.get('temperature', 1.0)}, "
                       f"Threshold: {ood_config.get('confidence_threshold', DEFAULT_CONFIDENCE_THRESHOLD)})")
    else:
        title_text += "\nOOD Detection: OFF"
    fig.suptitle(title_text, fontsize=14, weight='bold', y=0.98)
    # Add legend
    legend_elements = [
        patches.Patch(color='green', label='High Confidence (≥90%)'),
        patches.Patch(color='orange', label='Medium Confidence (70-89%)'),
        patches.Patch(color='red', label='Low Confidence (<70%) / OOD')
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=3, bbox_to_anchor=(0.5, 0.02))
    plt.subplots_adjust(top=0.9, bottom=0.1)
    if output_filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"prediction_results_{model_name.lower()}_{timestamp}.png"
    # Make sure the output directory exists before saving
    out_dir = os.path.dirname(output_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(output_filename, dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    print(f"Results visualization saved as: {output_filename}")
    plt.close(fig)  # release the figure; this script loops over many folders
    return output_filename
def generate_prediction_report():
    """Generate prediction reports for each image folder in a base directory or a single image folder."""
    print("=" * 60)
    print("PREDICTION RESULTS VISUALIZATION GENERATOR")
    print("=" * 60)
    # Select model
    config = select_model_for_testing()
    if config is None:
        return
    # Configure OOD detection
    print("\nConfiguring OOD Detection...")
    enable_str = input(f"Enable OOD detection? (y/n, default: {'y' if ENABLE_OOD_DETECTION else 'n'}): ").strip().lower()
    if enable_str == '':
        enable_ood = ENABLE_OOD_DETECTION
    else:
        enable_ood = enable_str in ['y', 'yes', '1', 'true']
    ood_config = {'enabled': enable_ood}
    if enable_ood:
        temp_str = input(f"Temperature parameter (default: {DEFAULT_TEMPERATURE}): ").strip()
        try:
            temperature = float(temp_str) if temp_str else DEFAULT_TEMPERATURE
        except ValueError:
            temperature = DEFAULT_TEMPERATURE
        thresh_str = input(f"Confidence threshold (default: {DEFAULT_CONFIDENCE_THRESHOLD}): ").strip()
        try:
            threshold = float(thresh_str) if thresh_str else DEFAULT_CONFIDENCE_THRESHOLD
        except ValueError:
            threshold = DEFAULT_CONFIDENCE_THRESHOLD
        # Energy-based OOD detection configuration; defaults come from DEFAULT_OOD_CONFIG,
        # since ood_config itself only holds the keys set so far
        enable_energy_str = input(f"Enable energy-based OOD detection? (y/n, default: {'y' if DEFAULT_OOD_CONFIG['use_energy'] else 'n'}): ").strip().lower()
        if enable_energy_str == '':
            enable_energy = DEFAULT_OOD_CONFIG['use_energy']
        else:
            enable_energy = enable_energy_str in ['y', 'yes', '1', 'true']
        if enable_energy:
            energy_thresh_str = input(f"Energy threshold (default: {DEFAULT_OOD_CONFIG['energy_threshold']}; a higher threshold flags fewer samples as OOD): ").strip()
            try:
                energy_threshold = float(energy_thresh_str) if energy_thresh_str else DEFAULT_OOD_CONFIG['energy_threshold']
            except ValueError:
                energy_threshold = DEFAULT_OOD_CONFIG['energy_threshold']
            ood_config.update({
                'use_energy': True,
                'energy_threshold': energy_threshold
            })
        ood_config.update({
            'temperature': temperature,
            'confidence_threshold': threshold
        })
    # Load model
    model, class_names, model_type, config = load_model_and_classes(config)
    if model is None:
        print("Failed to load model!")
        return
    # Prompt for base directory or image folder
    base_path = input("Enter the base directory path or a single image folder: ").strip()
    if not base_path:
        base_path = TEST_DIR
    if not os.path.isdir(base_path):
        print(f"Directory not found: {base_path}")
        return
    img_size = config['img_size']
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
    max_images = 100
    # Check if base_path itself contains images
    image_paths = []
    for ext in image_extensions:
        image_paths.extend(glob.glob(os.path.join(base_path, ext)))
    image_paths = sorted(image_paths, key=lambda x: os.path.basename(x).lower())
    if image_paths:  # base_path is an image folder
        print(f"\nDetected {len(image_paths)} images in selected folder: {os.path.basename(base_path)}")
        folder_name = os.path.basename(base_path.rstrip('/\\'))
        if len(image_paths) > max_images:
            print(f"Limiting to first {max_images} images for visualization")
            image_paths = image_paths[:max_images]
        results = []
        print("Processing images...")
        for i, image_path in enumerate(image_paths):
            print(f"Processing {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
            # Pass 'bicubic' as the resize method (not as the filename suffix)
            img_batch, img_pil = preprocess_image(image_path, img_size, resize_method='bicubic')
            if img_batch is None:
                continue
            pred_class, confidence, ood_detected, ood_reason, uncertainty_metrics = predict_single_image_with_perturbation(
                model, img_batch, class_names, model_type, ood_config, perturbation_config=DEFAULT_PERTURBATION_CONFIG
            )
            if pred_class is not None:
                results.append((image_path, img_pil, pred_class, confidence, ood_detected))
        if not results:
            print(f"No successful predictions in {folder_name} to visualize!")
            return
        output_file = os.path.join("visualization", f"{folder_name}_results.png")
        print(f"\nGenerating visualization for {len(results)} predictions in {folder_name}...")
        create_results_visualization(results, config, ood_config, output_filename=output_file)
        print(f"Saved visualization: {output_file}")
        # Print summary statistics
        print("\n" + "=" * 60)
        print(f"SUMMARY STATISTICS for {folder_name}")
        print("=" * 60)
        confidences = [conf for _, _, _, conf, _ in results]
        ood_detections = sum(1 for _, _, _, _, is_ood in results if is_ood)
        print(f"Total predictions: {len(results)}")
        print(f"Average confidence: {np.mean(confidences)*100:.1f}%")
        print(f"Highest confidence: {np.max(confidences)*100:.1f}%")
        print(f"Lowest confidence: {np.min(confidences)*100:.1f}%")
        print(f"OOD detections: {ood_detections}/{len(results)} ({100*ood_detections/len(results):.1f}%)")
        classes = [pred_class for _, _, pred_class, _, _ in results]
        unique_classes = list(set(classes))
        print("\nPredicted classes:")
        for cls in unique_classes:
            count = classes.count(cls)
            print(f"  {cls}: {count} ({100*count/len(results):.1f}%)")
        return
    # Otherwise, treat as base directory containing subfolders
    subfolders = [os.path.join(base_path, d) for d in os.listdir(base_path)
                  if os.path.isdir(os.path.join(base_path, d))]
    if not subfolders:
        print(f"No subfolders found in {base_path}")
        return
    print(f"Found {len(subfolders)} image folders to process")
    for folder_path in subfolders:
        folder_name = os.path.basename(folder_path.rstrip('/\\'))
        image_paths = []
        for ext in image_extensions:
            image_paths.extend(glob.glob(os.path.join(folder_path, ext)))
        image_paths = sorted(image_paths, key=lambda x: os.path.basename(x).lower())
        if not image_paths:
            print(f"No images found in folder {folder_name}")
            continue
        print(f"\nProcessing folder: {folder_name}")
        print(f"Found {len(image_paths)} images in folder {folder_name}")
        if len(image_paths) > max_images:
            print(f"Limiting to first {max_images} images for visualization")
            image_paths = image_paths[:max_images]
        results = []
        print("Processing images...")
        for i, image_path in enumerate(image_paths):
            print(f"Processing {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
            # Pass 'bicubic' as the resize method (not as the filename suffix)
            img_batch, img_pil = preprocess_image(image_path, img_size, resize_method='bicubic')
            if img_batch is None:
                continue
            pred_class, confidence, ood_detected, ood_reason, uncertainty_metrics = predict_single_image_with_perturbation(
                model, img_batch, class_names, model_type, ood_config, perturbation_config=DEFAULT_PERTURBATION_CONFIG
            )
            if pred_class is not None:
                results.append((image_path, img_pil, pred_class, confidence, ood_detected))
        if not results:
            print(f"No successful predictions in {folder_name} to visualize!")
            continue
        output_file = os.path.join("visualization", f"{folder_name}_results.png")
        print(f"\nGenerating visualization for {len(results)} predictions in {folder_name}...")
        create_results_visualization(results, config, ood_config, output_filename=output_file)
        print(f"Saved visualization: {output_file}")
        # Print summary statistics for this folder
        print("\n" + "=" * 60)
        print(f"SUMMARY STATISTICS for {folder_name}")
        print("=" * 60)
        confidences = [conf for _, _, _, conf, _ in results]
        ood_detections = sum(1 for _, _, _, _, is_ood in results if is_ood)
        print(f"Total predictions: {len(results)}")
        print(f"Average confidence: {np.mean(confidences)*100:.1f}%")
        print(f"Highest confidence: {np.max(confidences)*100:.1f}%")
        print(f"Lowest confidence: {np.min(confidences)*100:.1f}%")
        print(f"OOD detections: {ood_detections}/{len(results)} ({100*ood_detections/len(results):.1f}%)")
        classes = [pred_class for _, _, pred_class, _, _ in results]
        unique_classes = list(set(classes))
        print("\nPredicted classes:")
        for cls in unique_classes:
            count = classes.count(cls)
            print(f"  {cls}: {count} ({100*count/len(results):.1f}%)")
    print("\nAll folders processed.")
if __name__ == "__main__":
    generate_prediction_report()