Last active
July 17, 2025 13:41
-
-
Save omeganoob/d5d052acd8869791e3dead893f519bed to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import glob | |
import json | |
import math | |
import os | |
import warnings | |
from datetime import datetime | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import matplotlib.patches as patches | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import tensorflow as tf | |
from scipy.special import logsumexp | |
from tensorflow.keras.preprocessing import image | |
# Folder scanned for test images when the user just presses Enter at the
# directory prompt in generate_prediction_report().
TEST_DIR = "content/test"

# Defaults offered by the interactive OOD prompts in generate_prediction_report().
DEFAULT_TEMPERATURE = 1
DEFAULT_CONFIDENCE_THRESHOLD = 0.95
ENABLE_OOD_DETECTION = True

# Test-time input perturbations. Every enabled section contributes
# `num_samples` perturbed copies of the input, in addition to the original.
DEFAULT_PERTURBATION_CONFIG = {
    'gaussian_noise': dict(enabled=True, std=0.005, num_samples=10),
    'uniform_noise': dict(enabled=False, scale=0.005, num_samples=10),
    'dropout': dict(enabled=True, rate=0.03, num_samples=10),
    'mixup': dict(enabled=False, alpha=0.1, num_samples=2),
}

# Default OOD-detection settings (thresholds and switches consumed by
# detect_ood_with_perturbation and predict_single_image_with_perturbation).
DEFAULT_OOD_CONFIG = dict(
    enabled=True,
    temperature=1.0,
    confidence_threshold=0.5,
    entropy_threshold=1.0,
    variance_threshold=0.1,
    use_uncertainty=True,
    use_energy=True,
    energy_threshold=10.0,
    aggregation_method='mean',
)
def discover_trained_models():
    """Discover trained models whose artifacts exist under ./content.

    Every ``./content/model_config_*.json`` file is parsed. A config is kept
    only if at least one model location derived from its ``model_export_dir``
    exists on disk: ``<dir>.keras``, ``<dir>_keras`` or ``<dir>`` itself.
    Unreadable or malformed configs are reported and skipped so one broken
    file does not hide the others.

    Returns:
        List of dicts with keys ``config_file``, ``config`` and
        ``display_name``, ordered by config filename so that the numbering
        shown by select_model_for_testing() is stable between runs.
    """
    models = []
    # glob() returns files in filesystem order; sort for a reproducible menu.
    for config_file in sorted(glob.glob("./content/model_config_*.json")):
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)
            export_dir = config['model_export_dir']
            candidates = (export_dir + ".keras", export_dir + "_keras", export_dir)
            if any(os.path.exists(path) for path in candidates):
                models.append({
                    'config_file': config_file,
                    'config': config,
                    'display_name': f"{config['model_name'].upper()} (img_size: {config['img_size']})"
                })
        except Exception as e:
            # Deliberate best effort: report the bad config and keep scanning.
            print(f"Error reading config file {config_file}: {e}")
    return models
def select_model_for_testing():
    """Interactively choose one of the models found by discover_trained_models().

    The available models are listed with 1-based indices and the user is
    prompted until a valid index is entered.

    Returns:
        The selected model's config dict, or None when no models are found,
        the user presses Ctrl-C, or standard input reaches EOF.
    """
    models = discover_trained_models()
    if not models:
        print("No trained models found!")
        print("Please run the training script first.")
        return None
    print("Available trained models:")
    print("=" * 50)
    for i, model_info in enumerate(models, 1):
        print(f"{i:2d}. {model_info['display_name']}")
        print(f" Directory suffix: {model_info['config']['dir_suffix']}")
    print("=" * 50)
    while True:
        try:
            choice = input(f"Select model to test (1-{len(models)}): ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOF (closed/redirected stdin) must end the prompt; retrying
            # input() would just raise EOFError again forever.
            print("\nOperation cancelled.")
            return None
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        if not choice.isdecimal():
            print("Please enter a valid number.")
            continue
        choice_idx = int(choice) - 1
        if 0 <= choice_idx < len(models):
            selected_model = models[choice_idx]
            break
        print(f"Please enter a number between 1 and {len(models)}")
    config = selected_model['config']
    print(f"\nSelected: {config['model_name'].upper()}")
    print(f"Image Size: {config['img_size']}")
    print(f"Directory Suffix: {config['dir_suffix']}")
    return config
def load_model_and_classes(config):
    """Load the trained model and class-name list described by ``config``.

    Candidate model locations derived from ``config['model_export_dir']`` are
    tried in order: ``<dir>.keras`` and ``<dir>_keras`` via
    ``tf.keras.models.load_model``, then ``<dir>`` via ``tf.saved_model.load``.
    The first one that loads wins.

    Returns:
        ``(model, class_names, model_type, config)`` where ``model_type`` is
        "keras" or "savedmodel". ``(None, None, None, None)`` when no model
        could be loaded; ``class_names`` is None when the class-names JSON
        file cannot be read.
    """
    export_dir = config['model_export_dir']
    names_path = config['class_names_file']
    candidates = (
        (export_dir + ".keras", "keras", "Keras .keras file"),
        (export_dir + "_keras", "keras", "Keras directory format"),
        (export_dir, "savedmodel", "SavedModel format"),
    )
    loaded = None
    loaded_type = None
    for path, kind, label in candidates:
        print(f"Trying to load {label} from {path}...")
        try:
            if kind == "keras":
                loaded = tf.keras.models.load_model(path)
            else:
                loaded = tf.saved_model.load(path)
        except Exception as err:
            print(f"Failed to load {label}: {err}")
            continue
        loaded_type = kind
        print(f"{label} loaded successfully!")
        break
    if loaded is None:
        print("Failed to load any model format!")
        return None, None, None, None
    try:
        with open(names_path, 'r') as handle:
            names = json.load(handle)
    except Exception as err:
        print(f"Error loading class names: {err}")
        return loaded, None, loaded_type, config
    print(f"Class names loaded: {names}")
    return loaded, names, loaded_type, config
def preprocess_image(image_path, img_size, suffix='_resized', resize_method='bilinear', save_resized=False):
    """Load an image and resize it to a square ``img_size`` x ``img_size`` batch.

    Args:
        image_path: Path of the image file to load.
        img_size: Target height and width in pixels.
        suffix: Appended to the file stem of the optional saved copy.
        resize_method: Method name forwarded to ``tf.image.resize``
            (e.g. 'bilinear', 'bicubic'). Note it is the fourth positional
            parameter; pass it by keyword.
        save_resized: If True, also write the resized image as JPEG to
            ``content/resized/<stem><suffix>.jpg`` (directory created if needed).

    Returns:
        ``(resized_batch, resized_pil)``: a float32 tensor with a leading batch
        axis of size 1 (presumably (1, img_size, img_size, C), channels-last —
        confirm against the Keras data_format) and the resized PIL image; or
        ``(None, None)`` if loading or resizing fails. A failure while saving
        the optional copy is reported but does not discard the result.
    """
    try:
        # Load at native resolution; the resize below is antialiased and its
        # interpolation method is configurable, unlike load_img(target_size=...).
        target_img = image.load_img(image_path)
        img_array = tf.convert_to_tensor(image.img_to_array(target_img), dtype=tf.float32)
        img_batch = tf.expand_dims(img_array, axis=0)
        resized_img = tf.image.resize(img_batch, [img_size, img_size], method=resize_method, antialias=True)
        resized_img_pil = image.array_to_img(resized_img.numpy()[0])
    except Exception as e:
        print(f"Error preprocessing image {image_path}: {e}")
        return None, None
    if save_resized:
        save_dir = "content/resized"
        name, _ = os.path.splitext(os.path.basename(image_path))
        save_path = os.path.join(save_dir, f"{name}{suffix}.jpg")
        try:
            os.makedirs(save_dir, exist_ok=True)
            resized_img_pil.save(save_path, format="JPEG")
            print(f"Saved resized image as: {save_path}")
        except Exception as e:
            # Saving is optional; keep the preprocessed image usable.
            print(f"Error saving resized image {save_path}: {e}")
    return resized_img, resized_img_pil
def apply_temperature_scaling(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Scale logits by ``1 / temperature`` ahead of a softmax.

    Args:
        logits: Raw (pre-softmax) model outputs.
        temperature: Strictly positive divisor. Values above 1 soften the
            resulting distribution (lower confidence); values below 1 sharpen it.

    Returns:
        ``logits / temperature``.

    Raises:
        ValueError: If ``temperature`` is zero or negative.
    """
    invalid = temperature <= 0
    if invalid:
        raise ValueError("Temperature must be positive")
    scaled = logits / temperature
    return scaled
def add_gaussian_noise(inputs: np.ndarray, noise_std: float = 0.01) -> np.ndarray:
    """Return ``inputs`` with elementwise zero-mean Gaussian noise added.

    Args:
        inputs: Array-like exposing a ``.shape`` attribute.
        noise_std: Standard deviation of the noise.

    Returns:
        ``inputs + noise`` where noise is drawn from NumPy's global RNG
        with the same shape as ``inputs``.
    """
    perturbation = np.random.normal(loc=0, scale=noise_std, size=inputs.shape)
    noisy = inputs + perturbation
    return noisy
def add_uniform_noise(inputs: np.ndarray, noise_scale: float = 0.01) -> np.ndarray:
    """Return ``inputs`` with elementwise uniform noise in ``[-noise_scale, noise_scale)``.

    Args:
        inputs: Array-like exposing a ``.shape`` attribute.
        noise_scale: Half-width of the uniform noise interval.

    Returns:
        ``inputs + noise`` where noise is drawn from NumPy's global RNG
        with the same shape as ``inputs``.
    """
    perturbation = np.random.uniform(low=-noise_scale, high=noise_scale, size=inputs.shape)
    noisy = inputs + perturbation
    return noisy
def apply_dropout_perturbation(inputs: np.ndarray, dropout_rate: float = 0.1) -> np.ndarray:
    """Apply inverted-dropout noise to ``inputs``.

    Each element is zeroed independently with probability ``dropout_rate``.
    Survivors are scaled by ``1 / (1 - dropout_rate)`` so the expected value
    of every element is unchanged.

    Args:
        inputs: Array-like exposing a ``.shape`` attribute.
        dropout_rate: Drop probability, must lie in ``[0, 1)``.

    Returns:
        The perturbed inputs.

    Raises:
        ValueError: If ``dropout_rate`` is outside ``[0, 1)``. A rate of 1
            would divide by zero and produce NaN/inf instead of an input.
    """
    if not 0.0 <= dropout_rate < 1.0:
        raise ValueError(f"dropout_rate must be in [0, 1), got {dropout_rate}")
    keep_mask = np.random.random(inputs.shape) > dropout_rate
    return inputs * keep_mask / (1.0 - dropout_rate)
def apply_mixup_perturbation(inputs: np.ndarray, alpha: float = 0.2) -> Tuple[np.ndarray, float]:
    """Blend each sample with a randomly chosen sample of the same batch (mixup).

    Args:
        inputs: Array with a leading batch axis, i.e. shape (batch_size, ...).
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution the
            mixing weight is drawn from; ``alpha <= 0`` fixes the weight at 1.

    Returns:
        ``(mixed_inputs, lam)`` where ``mixed = lam * x + (1 - lam) * x[perm]``.
        Inputs with fewer than two dimensions or fewer than two samples are
        returned unchanged together with ``1.0``.
    """
    # Mixup needs a batch axis holding at least two samples.
    if len(inputs.shape) < 2 or inputs.shape[0] < 2:
        return inputs, 1.0
    lam = 1.0
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    partners = inputs[np.random.permutation(inputs.shape[0])]
    return lam * inputs + (1 - lam) * partners, lam
def generate_perturbations(inputs: np.ndarray, perturbation_config: Dict[str, Any]) -> List[np.ndarray]:
    """Build the list of inputs to evaluate: the original followed by perturbed copies.

    Sections are visited in a fixed order — gaussian_noise, uniform_noise,
    dropout, mixup. Each section whose ``enabled`` flag is truthy appends
    ``num_samples`` copies (default 3, or 2 for mixup), built with its own
    parameter (std / scale / rate / alpha, with defaults 0.01 / 0.01 / 0.1 / 0.2).

    Args:
        inputs: Input batch.
        perturbation_config: Mapping from section name to its settings dict.

    Returns:
        List whose first element is ``inputs`` itself.
    """
    # (section, parameter key, parameter default, sample-count default, builder)
    recipes = (
        ('gaussian_noise', 'std', 0.01, 3, lambda x, p: add_gaussian_noise(x, p)),
        ('uniform_noise', 'scale', 0.01, 3, lambda x, p: add_uniform_noise(x, p)),
        ('dropout', 'rate', 0.1, 3, lambda x, p: apply_dropout_perturbation(x, p)),
        ('mixup', 'alpha', 0.2, 2, lambda x, p: apply_mixup_perturbation(x, p)[0]),
    )
    outputs = [inputs]
    for section_key, param_key, param_default, count_default, build in recipes:
        section = perturbation_config.get(section_key, {})
        if not section.get('enabled', False):
            continue
        param = section.get(param_key, param_default)
        count = section.get('num_samples', count_default)
        outputs.extend(build(inputs, param) for _ in range(count))
    return outputs
def is_logits(predictions: np.ndarray, threshold: float = 10.0) -> bool:
    """Heuristically decide whether ``predictions`` are logits rather than probabilities.

    The output is treated as logits when any of these holds:
    * some value is negative;
    * the largest value exceeds ``threshold``;
    * some slice along the last axis sums to a value more than 0.1 away from 1.

    Args:
        predictions: Model outputs, class axis last.
        threshold: Maximum value still plausible for probabilities.

    Returns:
        Truthy if the outputs look like logits, falsy if they look like probabilities.
    """
    exceeds_threshold = np.max(predictions) > threshold
    negative_present = np.min(predictions) < 0
    row_totals = np.sum(predictions, axis=-1)
    unnormalized_rows = abs(row_totals - 1.0) > 0.1
    return negative_present or exceeds_threshold or np.any(unnormalized_rows)
def softmax_with_temperature(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Softmax over the last axis of ``logits / temperature``.

    TensorFlow tensors are handled by ``tf.nn.softmax`` and returned as
    tensors; everything else uses a max-subtracted (overflow-safe) NumPy softmax.

    Args:
        logits: Raw model outputs, class axis last.
        temperature: Strictly positive scaling factor (T > 1 softens, T < 1 sharpens).

    Returns:
        Probabilities with the same shape as ``logits``.

    Raises:
        ValueError: If ``temperature`` is zero or negative (same rule as
            apply_temperature_scaling). Previously such values silently
            produced NaN/inf probabilities.
    """
    if np.any(np.asarray(temperature) <= 0):
        raise ValueError("Temperature must be positive")
    scaled_logits = logits / temperature
    # Test for NumPy first so plain arrays never need TensorFlow at all.
    if not isinstance(scaled_logits, np.ndarray) and isinstance(scaled_logits, tf.Tensor):
        return tf.nn.softmax(scaled_logits, axis=-1)
    # Subtracting the row max keeps exp() from overflowing; it cancels out.
    shifted = scaled_logits - np.max(scaled_logits, axis=-1, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
def aggregate_predictions(predictions_list: List[np.ndarray],
                          aggregation_method: str = 'mean',
                          weights: Optional[List[float]] = None) -> np.ndarray:
    """Combine the predictions obtained from several perturbed forward passes.

    Args:
        predictions_list: Equally-shaped prediction arrays, class axis last.
        aggregation_method: One of
            'mean'          - elementwise mean;
            'weighted_mean' - weighted mean (uniform weights if ``weights`` is None);
            'median'        - elementwise median;
            'vote'          - one-hot of the most frequent argmax class, per
                              leading row when predictions are batched
                              (ties go to the lowest class index).
        weights: One weight per prediction for 'weighted_mean'.

    Returns:
        Aggregated array with the shape of a single prediction. A single-element
        list is returned as-is (its only element) without aggregation.

    Raises:
        ValueError: On an empty list or an unknown aggregation method.
    """
    if not predictions_list:
        raise ValueError("Empty predictions list")
    if len(predictions_list) == 1:
        return predictions_list[0]
    predictions_array = np.array(predictions_list)
    if aggregation_method == 'mean':
        return np.mean(predictions_array, axis=0)
    if aggregation_method == 'weighted_mean':
        if weights is None:
            weights = [1.0] * len(predictions_list)
        weights = np.array(weights) / np.sum(weights)
        return np.average(predictions_array, axis=0, weights=weights)
    if aggregation_method == 'median':
        return np.median(predictions_array, axis=0)
    if aggregation_method == 'vote':
        def _mode(labels):
            # np.unique returns sorted values, so argmax breaks ties toward the lowest class.
            values, counts = np.unique(labels, return_counts=True)
            return values[np.argmax(counts)]

        votes = np.argmax(predictions_array, axis=-1)
        result = np.zeros_like(predictions_array[0])
        if result.ndim > 1:
            # One majority vote per leading row (each row computed once).
            for i in range(result.shape[0]):
                result[i, _mode(votes[:, i])] = 1.0
        else:
            result[_mode(votes)] = 1.0
        return result
    raise ValueError(f"Unknown aggregation method: {aggregation_method}")
def compute_prediction_uncertainty(predictions_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Summarize how uncertain / unstable the perturbed predictions are.

    Args:
        predictions_list: Non-empty list of equally-shaped probability arrays,
            class axis last.

    Returns:
        ``(predictive_entropy, prediction_variance)``, each shaped like one
        prediction without its class axis:
        * predictive_entropy - entropy (natural log) of the mean prediction,
          i.e. the total predictive uncertainty of the ensemble;
        * prediction_variance - per-class variance across perturbations,
          averaged over classes, i.e. how much the perturbations disagree.
        Both are zeros when only one prediction is supplied.

    Raises:
        ValueError: If ``predictions_list`` is empty.
    """
    if not predictions_list:
        raise ValueError("Empty predictions list")
    if len(predictions_list) == 1:
        reduced_shape = np.shape(predictions_list[0])[:-1]
        return np.zeros(reduced_shape), np.zeros(reduced_shape)
    stacked = np.array(predictions_list)
    mean_predictions = np.mean(stacked, axis=0)
    # The epsilon keeps log() finite for classes with zero probability.
    predictive_entropy = -np.sum(mean_predictions * np.log(mean_predictions + 1e-10), axis=-1)
    prediction_variance = np.mean(np.var(stacked, axis=0), axis=-1)
    return predictive_entropy, prediction_variance
# Energy-based OOD score.
def compute_energy_score(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Compute the energy score E(x) = -T * logsumexp(f(x) / T) over the last axis.

    Lower energy suggests an in-distribution sample, higher energy an OOD one.

    Args:
        logits: Raw model outputs (NumPy array or TF tensor), class axis last.
        temperature: T in the formula; the default 1.0 gives plain -logsumexp.

    Returns:
        Energy per sample (input shape without the class axis): a TF tensor for
        TF-tensor logits, otherwise a NumPy value.

    Raises:
        ValueError: If ``temperature`` is zero or negative.

    Note:
        If ``logits`` look like probabilities (see is_logits) they are replaced
        by log(p). For (near-)normalized probabilities logsumexp(log p) equals
        log(sum p), which is close to 0, so the energy is almost constant and
        carries little OOD signal; feed the model's pre-softmax outputs instead.
    """
    if np.any(np.asarray(temperature) <= 0):
        raise ValueError("Temperature must be positive")
    if not is_logits(logits):
        warnings.warn("Energy-based detection requires logits, not probabilities!")
        logits = np.log(np.clip(logits, 1e-10, 1.0))
    if not isinstance(logits, np.ndarray) and isinstance(logits, tf.Tensor):
        return -temperature * tf.reduce_logsumexp(logits / temperature, axis=-1)
    # scipy's logsumexp is already numerically stable; no manual fallback needed.
    return -temperature * logsumexp(np.asarray(logits) / temperature, axis=-1)
def detect_ood_with_perturbation(prediction_list_logits: List[np.ndarray], predictions_list: List[np.ndarray],
                                 ood_config: Dict[str, Any]) -> Tuple[bool, str, np.ndarray, Dict[str, float]]:
    """Decide whether an input is out-of-distribution from its perturbed predictions.

    Args:
        prediction_list_logits: Raw model outputs, one per perturbation. Only the
            first entry (the unperturbed input) is used, for the energy score.
        predictions_list: Probability outputs, one per perturbation.
        ood_config: Detection settings. Keys read (defaults in parentheses):
            enabled (False), aggregation_method ('mean'),
            confidence_threshold (0.5), entropy_threshold (1.0),
            variance_threshold (0.1), use_uncertainty (True),
            use_energy (False), energy_threshold (10.0).

    Returns:
        ``(is_ood, reason, aggregated_probs, uncertainty_metrics)``. When
        detection is disabled: ``(False, "OOD detection disabled",
        predictions_list[0], {})``.

    The criteria are OR-ed: low max confidence always counts; high entropy and
    high variance count when ``use_uncertainty``; high energy counts when
    ``use_energy``. Metrics are reported for the first sample of the batch.
    """
    if not ood_config.get('enabled', False):
        return False, "OOD detection disabled", predictions_list[0], {}

    def _first_value(values):
        # First element as a Python float, whether values is 0-d or batched.
        return float(np.ravel(values)[0])

    aggregation_method = ood_config.get('aggregation_method', 'mean')
    aggregated_probs = aggregate_predictions(predictions_list, aggregation_method)
    predictive_entropy, prediction_variance = compute_prediction_uncertainty(predictions_list)

    confidence_threshold = ood_config.get('confidence_threshold', 0.5)
    entropy_threshold = ood_config.get('entropy_threshold', 1.0)
    variance_threshold = ood_config.get('variance_threshold', 0.1)
    energy_threshold = ood_config.get('energy_threshold', 10.0)
    use_uncertainty = ood_config.get('use_uncertainty', True)
    # A single default for use_energy: energy is computed only when it is
    # enabled, and the reason strings below format it as a number.
    use_energy = ood_config.get('use_energy', False)

    max_confidence = _first_value(np.max(aggregated_probs, axis=-1))
    entropy_value = _first_value(predictive_entropy)
    variance_value = _first_value(prediction_variance)

    energy_value = None
    energy_ood = False
    if use_energy:
        energy_value = _first_value(compute_energy_score(prediction_list_logits[0]))
        energy_ood = energy_value > energy_threshold

    confidence_ood = max_confidence < confidence_threshold
    entropy_ood = entropy_value > entropy_threshold
    variance_ood = variance_value > variance_threshold

    if use_uncertainty:
        flags = [
            (confidence_ood, f"Low confidence: {max_confidence:.4f} < {confidence_threshold}"),
            (entropy_ood, f"High entropy: {entropy_value:.4f} > {entropy_threshold}"),
            (variance_ood, f"High variance: {variance_value:.4f} > {variance_threshold}"),
        ]
        summary = (f"In-distribution: conf={max_confidence:.4f}, "
                   f"entropy={entropy_value:.4f}, var={variance_value:.4f}")
        if use_energy:
            flags.append((energy_ood, f"High energy: {energy_value:.4f} > {energy_threshold}"))
            summary += f", energy={energy_value:.4f}"
        reasons = [message for flagged, message in flags if flagged]
        is_ood = bool(reasons)
        reason = "; ".join(reasons) if reasons else summary
    elif use_energy:
        is_ood = energy_ood or confidence_ood
        if energy_ood:
            reason = f"High energy: {energy_value:.4f} > {energy_threshold}"
        elif confidence_ood:
            reason = f"Low confidence: {max_confidence:.4f} < {confidence_threshold}"
        else:
            reason = f"Low energy: {energy_value:.4f} <= {energy_threshold}"
    else:
        is_ood = confidence_ood
        if is_ood:
            reason = f"Low confidence: {max_confidence:.4f} < {confidence_threshold}"
        else:
            reason = f"High confidence: {max_confidence:.4f} >= {confidence_threshold}"

    uncertainty_metrics = {
        'predictive_entropy': entropy_value,
        'prediction_variance': variance_value,
        'max_confidence': max_confidence,
        'num_perturbations': len(predictions_list)
    }
    if energy_value is not None:
        uncertainty_metrics['energy_score'] = energy_value
    return is_ood, reason, aggregated_probs, uncertainty_metrics
def _resolve_inference_fn(model, model_type):
    """Return a callable mapping one input batch to the model's raw outputs."""
    if model_type == "keras":
        return lambda batch: model(batch, training=False)
    if not hasattr(model, 'signatures'):
        return lambda batch: model(tf.convert_to_tensor(batch, dtype=tf.float32))
    # Resolve the serving signature and its input name once, not per perturbation.
    signatures = model.signatures
    if 'serving_default' in signatures:
        serving_fn = signatures['serving_default']
    else:
        serving_fn = signatures[list(signatures.keys())[0]]
    input_key = list(serving_fn.structured_input_signature[1].keys())[0]

    def _run(batch):
        outputs = serving_fn(**{input_key: tf.convert_to_tensor(batch, dtype=tf.float32)})
        return list(outputs.values())[0] if isinstance(outputs, dict) else outputs

    return _run


def _to_probabilities(raw_outputs, temperature):
    """Softmax raw outputs that look like logits; pass probabilities through."""
    if is_logits(raw_outputs):
        return softmax_with_temperature(raw_outputs, temperature)
    if temperature != 1.0:
        warnings.warn("Cannot apply temperature scaling to probabilities. Need logits!")
    return raw_outputs


def predict_single_image_with_perturbation(model, img_batch: np.ndarray,
                                           class_names: List[str],
                                           model_type: str,
                                           ood_config: Optional[Dict[str, Any]] = DEFAULT_OOD_CONFIG,
                                           perturbation_config: Optional[Dict[str, Any]] = None) -> Tuple[str, float, bool, str, Dict[str, float]]:
    """Predict one preprocessed image, using input perturbations to estimate uncertainty.

    The input is perturbed according to ``perturbation_config`` (see
    generate_perturbations); every variant is run through the model, outputs
    that look like logits are softmaxed with ``ood_config['temperature']``
    (default 1.0), and the resulting set is passed to
    detect_ood_with_perturbation.

    Args:
        model: Loaded Keras model or SavedModel object.
        img_batch: Preprocessed image batch (a leading batch axis is assumed).
        class_names: Class names indexed by class id; ids past the end are
            reported as ``Class_<id>``.
        model_type: "keras" or "savedmodel".
        ood_config: OOD settings; None disables detection. The shared module
            default is only read, never mutated.
        perturbation_config: Perturbation settings; None disables perturbation.

    Returns:
        ``(predicted_class, confidence, ood_detected, ood_reason, uncertainty_metrics)``,
        or ``(None, None, False, "Prediction error", {})`` after printing
        diagnostics if anything fails.
    """
    try:
        if ood_config is None:
            ood_config = {'enabled': False}
        if perturbation_config is None:
            perturbation_config = {'gaussian_noise': {'enabled': False}}
        temperature = ood_config.get('temperature', 1.0)
        run_model = _resolve_inference_fn(model, model_type)

        all_predictions = []
        all_predictions_logits = []
        for perturbed_input in generate_perturbations(img_batch, perturbation_config):
            raw = run_model(perturbed_input)
            predictions_np = raw.numpy() if hasattr(raw, 'numpy') else np.array(raw)
            if predictions_np.ndim == 1:
                predictions_np = predictions_np.reshape(1, -1)
            # Raw outputs are kept separately for the energy score.
            all_predictions_logits.append(predictions_np)
            all_predictions.append(_to_probabilities(predictions_np, temperature))

        ood_detected, ood_reason, final_probs, uncertainty_metrics = detect_ood_with_perturbation(
            all_predictions_logits, all_predictions, ood_config
        )
        predicted_class_index = int(np.argmax(final_probs, axis=-1)[0])
        final_confidence = float(np.max(final_probs, axis=-1)[0])
        if class_names and predicted_class_index < len(class_names):
            predicted_class = class_names[predicted_class_index]
        else:
            predicted_class = f"Class_{predicted_class_index}"
        return predicted_class, final_confidence, ood_detected, ood_reason, uncertainty_metrics
    except Exception as e:
        print(f"Error during prediction: {e}")
        print(f"Model type: {type(model)}")
        if hasattr(model, 'signatures'):
            print(f"Available signatures: {list(model.signatures.keys())}")
            for sig_name, sig in model.signatures.items():
                print(f"Signature '{sig_name}':")
                print(f" Inputs: {sig.structured_input_signature}")
                print(f" Outputs: {sig.structured_outputs}")
        return None, None, False, "Prediction error", {}
def create_results_visualization(results, config, ood_config, output_filename=None):
    """Render a grid of result thumbnails and save it as a PNG.

    Each thumbnail shows the predicted class and confidence. Its title and
    border are green (>= 90%), orange (>= 70%) or red (< 70%); OOD images are
    always red, matching the legend.

    Args:
        results: List of ``(image_path, pil_image, predicted_class, confidence, is_ood)``.
        config: Model config; reads 'model_name' and 'img_size'.
        ood_config: OOD settings; reads 'enabled' and, when enabled,
            'temperature' and 'confidence_threshold'.
        output_filename: Destination path; missing parent directories are
            created. Defaults to ``prediction_results_<model>_<timestamp>.png``.

    Returns:
        The path written, or None when ``results`` is empty.
    """
    if not results:
        print("No results to visualize!")
        return None
    n_images = len(results)
    cols = min(6, n_images)  # at most 6 thumbnails per row
    rows = math.ceil(n_images / cols)
    # squeeze=False always returns a 2-D array of Axes, so a single row or a
    # single image needs no special-casing.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 6, rows * 5), squeeze=False)
    axes_flat = axes.ravel()
    for ax, (image_path, img_pil, pred_class, confidence, is_ood) in zip(axes_flat, results):
        ax.imshow(img_pil)
        # Hide ticks only: axis('off') would also hide the colored spines below.
        ax.set_xticks([])
        ax.set_yticks([])
        if is_ood:
            title_color, status = 'red', 'OOD'
        elif confidence >= 0.9:
            title_color, status = 'green', 'High Conf'
        elif confidence >= 0.7:
            title_color, status = 'orange', 'Med Conf'
        else:
            title_color, status = 'red', 'Low Conf'
        ax.set_title(f"{pred_class}\n{confidence * 100:.1f}% ({status})",
                     fontsize=10, color=title_color, weight='bold')
        for spine in ax.spines.values():
            spine.set_edgecolor(title_color)
            spine.set_linewidth(3)
    for ax in axes_flat[n_images:]:
        ax.axis('off')

    model_name = config['model_name'].upper()
    img_size = config['img_size']
    ood_count = sum(1 for *_, is_ood in results if is_ood)
    avg_confidence = np.mean([conf for _, _, _, conf, _ in results]) * 100
    title_text = f"{model_name} Model Results (Image Size: {img_size}x{img_size})\n"
    title_text += f"Total Images: {n_images} | OOD Detected: {ood_count} | Avg Confidence: {avg_confidence:.1f}%"
    if ood_config.get('enabled', False):
        # The OOD config stores its threshold under 'confidence_threshold'.
        title_text += (f"\nOOD Detection: ON (Temp: {ood_config.get('temperature', 1.0)}, "
                       f"Threshold: {ood_config.get('confidence_threshold', 0.5)})")
    else:
        title_text += "\nOOD Detection: OFF"
    fig.suptitle(title_text, fontsize=14, weight='bold', y=0.98)

    legend_elements = [
        patches.Patch(color='green', label='High Confidence (≥90%)'),
        patches.Patch(color='orange', label='Medium Confidence (70-89%)'),
        patches.Patch(color='red', label='Low Confidence (<70%) / OOD')
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=3, bbox_to_anchor=(0.5, 0.02))
    fig.subplots_adjust(top=0.9, bottom=0.1)

    if output_filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"prediction_results_{model_name.lower()}_{timestamp}.png"
    output_dir = os.path.dirname(output_filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    try:
        fig.savefig(output_filename, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
    finally:
        # Called once per folder; close the figure so open figures don't pile up.
        plt.close(fig)
    print(f"Results visualization saved as: {output_filename}")
    return output_filename
def _prompt_yes_no(prompt, default):
    """Ask a yes/no question; an empty answer returns ``default``."""
    answer = input(prompt).strip().lower()
    if answer == '':
        return default
    return answer in ['y', 'yes', '1', 'true']


def _prompt_float(prompt, default):
    """Read a float; empty or unparsable input returns ``default``."""
    raw = input(prompt).strip()
    try:
        return float(raw) if raw else default
    except ValueError:
        return default


def _configure_ood():
    """Interactively build the OOD config, starting from a copy of DEFAULT_OOD_CONFIG."""
    print("\nConfiguring OOD Detection...")
    enable_ood = _prompt_yes_no(
        f"Enable OOD detection? (y/n, default: {'y' if ENABLE_OOD_DETECTION else 'n'}): ",
        ENABLE_OOD_DETECTION)
    # Start from the defaults so every key read downstream exists; dict() copies,
    # so the module-level default is never mutated.
    ood_config = dict(DEFAULT_OOD_CONFIG, enabled=enable_ood)
    if not enable_ood:
        return ood_config
    temperature = _prompt_float(f"Temperature parameter (default: {DEFAULT_TEMPERATURE}): ",
                                DEFAULT_TEMPERATURE)
    if not temperature > 0:
        print(f"Temperature must be positive; using default {DEFAULT_TEMPERATURE}")
        temperature = DEFAULT_TEMPERATURE
    threshold = _prompt_float(f"Confidence threshold (default: {DEFAULT_CONFIDENCE_THRESHOLD}): ",
                              DEFAULT_CONFIDENCE_THRESHOLD)
    use_energy = _prompt_yes_no(
        f"Enable energy-based OOD detection? (y/n, default: {'y' if ood_config['use_energy'] else 'n'}): ",
        ood_config['use_energy'])
    ood_config['use_energy'] = use_energy
    if use_energy:
        default_energy = ood_config['energy_threshold']
        # Samples with energy above the threshold are flagged, so a lower
        # threshold flags more samples (stricter).
        ood_config['energy_threshold'] = _prompt_float(
            f"Energy threshold (default: {default_energy}, lower means stricter OOD): ",
            default_energy)
    ood_config.update({
        'temperature': temperature,
        'confidence_threshold': threshold
    })
    return ood_config


def _find_images(folder, extensions):
    """Return image files directly inside ``folder``, de-duplicated and sorted by name."""
    paths = set()
    for ext in extensions:
        # A set: on case-insensitive filesystems '*.jpg' and '*.JPG' match the same files.
        paths.update(glob.glob(os.path.join(folder, ext)))
    return sorted(paths, key=lambda p: os.path.basename(p).lower())


def _predict_folder(image_paths, model, class_names, model_type, img_size, ood_config, max_images):
    """Predict up to ``max_images`` images; return ``(path, pil, class, confidence, is_ood)`` tuples."""
    if len(image_paths) > max_images:
        print(f"Limiting to first {max_images} images for visualization")
        image_paths = image_paths[:max_images]
    results = []
    print("Processing images...")
    for i, image_path in enumerate(image_paths, 1):
        print(f"Processing {i}/{len(image_paths)}: {os.path.basename(image_path)}")
        # Keyword argument: the third positional parameter is `suffix`, not the method.
        img_batch, img_pil = preprocess_image(image_path, img_size, resize_method='bicubic')
        if img_batch is None:
            continue
        pred_class, confidence, ood_detected, _, _ = predict_single_image_with_perturbation(
            model, img_batch, class_names, model_type, ood_config,
            perturbation_config=DEFAULT_PERTURBATION_CONFIG
        )
        if pred_class is not None:
            results.append((image_path, img_pil, pred_class, confidence, ood_detected))
    return results


def _report_folder(folder_name, results, config, ood_config):
    """Save the visualization for one folder and print its summary statistics."""
    from collections import Counter

    os.makedirs("visualization", exist_ok=True)
    output_file = os.path.join("visualization", f"{folder_name}_results.png")
    print(f"\nGenerating visualization for {len(results)} predictions in {folder_name}...")
    create_results_visualization(results, config, ood_config, output_filename=output_file)
    print(f"Saved visualization: {output_file}")

    print("\n" + "=" * 60)
    print(f"SUMMARY STATISTICS for {folder_name}")
    print("=" * 60)
    total = len(results)
    confidences = [conf for _, _, _, conf, _ in results]
    ood_detections = sum(1 for *_, is_ood in results if is_ood)
    print(f"Total predictions: {total}")
    print(f"Average confidence: {np.mean(confidences)*100:.1f}%")
    print(f"Highest confidence: {np.max(confidences)*100:.1f}%")
    print(f"Lowest confidence: {np.min(confidences)*100:.1f}%")
    print(f"OOD detections: {ood_detections}/{total} ({100*ood_detections/total:.1f}%)")
    print("\nPredicted classes:")
    for cls, count in Counter(pred for _, _, pred, _, _ in results).most_common():
        print(f" {cls}: {count} ({100*count/total:.1f}%)")


def generate_prediction_report():
    """Interactive driver: pick a model, configure OOD detection, report on image folders.

    If the chosen directory itself contains images it is processed as a single
    folder; otherwise each immediate subfolder is processed in sorted order.
    For every folder a thumbnail grid is written to
    ``visualization/<folder>_results.png`` and summary statistics are printed.
    """
    print("=" * 60)
    print("PREDICTION RESULTS VISUALIZATION GENERATOR")
    print("=" * 60)
    config = select_model_for_testing()
    if config is None:
        return
    ood_config = _configure_ood()
    model, class_names, model_type, config = load_model_and_classes(config)
    if model is None:
        print("Failed to load model!")
        return

    base_path = input("Enter the base directory path or a single image folder: ").strip() or TEST_DIR
    if not os.path.isdir(base_path):
        print(f"Directory not found: {base_path}")
        return
    img_size = config['img_size']
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
    max_images = 100

    # Case 1: the chosen directory itself holds images.
    image_paths = _find_images(base_path, image_extensions)
    if image_paths:
        folder_name = os.path.basename(base_path.rstrip('/\\'))
        print(f"\nDetected {len(image_paths)} images in selected folder: {folder_name}")
        results = _predict_folder(image_paths, model, class_names, model_type,
                                  img_size, ood_config, max_images)
        if not results:
            print(f"No successful predictions in {folder_name} to visualize!")
            return
        _report_folder(folder_name, results, config, ood_config)
        return

    # Case 2: a base directory whose subfolders hold the images.
    subfolders = sorted(os.path.join(base_path, d) for d in os.listdir(base_path)
                        if os.path.isdir(os.path.join(base_path, d)))
    if not subfolders:
        print(f"No subfolders found in {base_path}")
        return
    print(f"Found {len(subfolders)} image folders to process")
    for folder_path in subfolders:
        folder_name = os.path.basename(folder_path.rstrip('/\\'))
        image_paths = _find_images(folder_path, image_extensions)
        if not image_paths:
            print(f"No images found in folder {folder_name}")
            continue
        print(f"\nProcessing folder: {folder_name}")
        print(f"Found {len(image_paths)} images in folder {folder_name}")
        results = _predict_folder(image_paths, model, class_names, model_type,
                                  img_size, ood_config, max_images)
        if not results:
            print(f"No successful predictions in {folder_name} to visualize!")
            continue
        _report_folder(folder_name, results, config, ood_config)
    print("\nAll folders processed.")
if __name__ == '__main__':
    # Launch the interactive report only when run as a script, not on import.
    generate_prediction_report()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment