Last active
November 24, 2023 09:56
-
-
Save the-bass/cae9f3976866776dea17a5049013258d to your computer and use it in GitHub Desktop.
Calculating the confusion matrix between two PyTorch tensors (a batch of predictions). Last tested with PyTorch 0.4.1.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def confusion(prediction, truth):
    """Return the confusion-matrix counts for binary `prediction` and `truth`.

    Counts the positions where the values of `prediction` and `truth` are

    - 1 and 1 (True Positive)
    - 1 and 0 (False Positive)
    - 0 and 0 (True Negative)
    - 0 and 1 (False Negative)

    The counts are computed with explicit boolean masks. The older
    element-wise division trick (1/1 -> 1, 1/0 -> inf, 0/0 -> nan, 0/1 -> 0)
    miscounted other values: 2/2 == 1 was a "true positive", and NaN in
    `prediction` was a "true negative". It also depended on
    division-by-zero semantics that vary by dtype and PyTorch version.
    Positions where either tensor holds a value other than 0 or 1 are not
    counted in any category.

    Args:
        prediction: tensor of predicted labels (0 or 1); float, integer or
            boolean dtype.
        truth: tensor of ground-truth labels (0 or 1), broadcastable with
            `prediction`.

    Returns:
        Tuple of Python ints
        ``(true_positives, false_positives, true_negatives, false_negatives)``.
    """
    predicted_positive = prediction == 1
    predicted_negative = prediction == 0
    actual_positive = truth == 1
    actual_negative = truth == 0

    # `&` broadcasts exactly like the element-wise ops it replaces.
    true_positives = torch.sum(predicted_positive & actual_positive).item()
    false_positives = torch.sum(predicted_positive & actual_negative).item()
    true_negatives = torch.sum(predicted_negative & actual_negative).item()
    false_negatives = torch.sum(predicted_negative & actual_positive).item()
    return true_positives, false_positives, true_negatives, false_negatives
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import torch | |
from aux import confusion | |
class TestConfusion(unittest.TestCase):
    """Unit tests for `confusion`."""

    def test_with_valid_tensors(self):
        # Row-by-row pairs (prediction, truth):
        #   TP, TP, FP, TN, FN, TN, TN, FN, FN, FN
        # Both tensors are (10, 1) column vectors. The dtype is stated
        # explicitly so the tensors are float regardless of which literals
        # are used to write the data.
        prediction = torch.tensor(
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], dtype=torch.float
        ).unsqueeze(1)
        truth = torch.tensor(
            [1, 1, 0, 0, 1, 0, 0, 1, 1, 1], dtype=torch.float
        ).unsqueeze(1)

        tp, fp, tn, fn = confusion(prediction, truth)

        # Checked in the order tp, fp, tn, fn.
        for counted, wanted in zip((tp, fp, tn, fn), (2, 1, 3, 4)):
            self.assertEqual(counted, wanted)
if __name__ == '__main__':
    # Run the test cases when this file is executed as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@AminJun
Could you explain what each row in the tensor represents (TP, FP, TN or FN)?