Created July 18, 2022 14:18
-
-
Save AdityaKane2001/60e39e5ace9906004152d649dffe24ee to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy.stats import norm as dist_model | |
import numpy as np | |
import torch | |
# Module-level placeholders; the real training script is expected to fill
# these in before the epoch loop runs.
cfg = None  # training config object; the loop reads `cfg.epochs` and `cfg.scale`
device = None  # target passed to `.to(...)` for every batch tensor
# Labels of the in-distribution ("seen") classes.
seen_classes = [*range(5)]
# Label emitted for a sample that is rejected as out-of-distribution.
OOD_CLASS_NUMBER = -1
def fit(prob_pos_X):
    """Fit a normal distribution to one class's positive scores mirrored around 1.

    Every score ``p`` is kept and its reflection ``2 - p`` is added. The pooled
    sample is therefore symmetric about 1, so the fitted mean is 1 (up to
    floating-point error), and the fitted spread measures how far the genuine
    scores fall below 1.

    Args:
        prob_pos_X: iterable of scores that samples of one class received in
            that class's own output column (presumably sigmoid probabilities
            in [0, 1] -- confirm against the model). Any iterable is accepted,
            including single-pass ones such as generators.

    Returns:
        ``(mu, std)`` as estimated by ``dist_model.fit`` (``scipy.stats.norm``).
    """
    # Materialise once: iterating the input twice would leave the mirrored
    # half empty for generators/iterators and silently bias the fit.
    scores = list(prob_pos_X)
    mirrored = [2 - p for p in scores]
    pos_mu, pos_std = dist_model.fit(scores + mirrored)
    return pos_mu, pos_std
def _collect_outputs(model, loader):
    """Run ``model`` over every batch in ``loader``; gather outputs and labels.

    ``batch[0]`` is fed to the model and ``batch[1]`` is taken as the label
    tensor. The model is switched to eval mode and gradients are disabled,
    since this pass is inference only.

    Args:
        model: network with an ``eval()`` method, called on ``batch[0]``.
        loader: iterable of batches (e.g. a DataLoader).

    Returns:
        ``(outputs, labels)``: the model outputs concatenated along dim 0 as a
        numpy array, and the labels concatenated along axis 0.
    """
    outputs, labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            batch = [elem.to(device) for elem in batch]
            out = model(batch[0])
            # A single-sample batch may come back squeezed to 1-D. Restore the
            # batch dimension so every chunk concatenates along dim 0. This is
            # done for train and test alike so both paths stay consistent.
            if out.dim() == 1:
                out = out.unsqueeze(0)
            outputs.append(out)
            labels.append(batch[1].cpu().numpy())
    return (
        torch.concat(outputs, dim=0).cpu().numpy(),
        np.concatenate(labels, axis=0),
    )


def _predict_open_set(probs, mu_stds, scale):
    """Label each row of ``probs`` with a seen class or ``OOD_CLASS_NUMBER``.

    The highest-scoring column of a row is the candidate class. It is accepted
    only when its score is strictly greater than ``max(0.5, 1 - scale * std)``,
    where ``std`` is the standard deviation fitted for that column. Otherwise
    the row is rejected as out-of-distribution.

    Args:
        probs: per-sample score rows (presumably 2-D: samples x classes).
        mu_stds: per-column ``[mu, std]`` pairs, ordered like ``seen_classes``.
        scale: multiplier applied to ``std`` (``cfg.scale`` in the loop below).

    Returns:
        list with one predicted label per row.
    """
    predictions = []
    for p in probs:
        max_class = int(np.argmax(p))  # column index of the top score
        max_value = np.max(p)
        threshold = max(0.5, 1. - scale * mu_stds[max_class][1])
        if max_value > threshold:
            # Map the column back to its label. This is identical to the
            # column index when seen_classes == range(n).
            predictions.append(seen_classes[max_class])
        else:
            predictions.append(OOD_CLASS_NUMBER)
    return predictions


for epoch in range(cfg.epochs):
    model = None  # placeholder: build the model here
    # Train model

    seen_train_X_predictions, seen_train_y = _collect_outputs(model, train_dl)

    # Fit one Gaussian per seen class on the scores that its own training
    # samples receive in that class's output column. Assumes column i
    # corresponds to seen_classes[i] -- confirm against the model head.
    mu_stds = []
    for i, cls in enumerate(seen_classes):
        pos_mu, pos_std = fit(seen_train_X_predictions[seen_train_y == cls, i])
        mu_stds.append([pos_mu, pos_std])

    # The test loader also includes OOD samples.
    test_X_pred, test_y_gt = _collect_outputs(model, test_dl)
    test_y_pred = _predict_open_set(test_X_pred, mu_stds, cfg.scale)
    accuracy, fscore = calculate_metrics(test_y_gt, test_y_pred)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.