Last active
January 19, 2022 10:42
-
-
Save sherwoac/ddfd4f9e4a5e60e883c348ad81607b6e to your computer and use it in GitHub Desktop.
_get_unit_square_intercepts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _get_unit_square_intercepts(self, slopes, intercept): | |
""" | |
returns unit square intercepts for given slope (a) and intercepts (b) | |
y = ax + b | |
solves: | |
right: y = a + b | |
x = 1 | |
y = slopes + intercept | |
left: y = b | |
x = 0 | |
y = intercept | |
top: 1 = ax + b | |
x = torch.divide(1 - intercept, slopes) | |
y = 1 | |
bottom: 0 = ax + b | |
x = torch.divide(- intercept, slopes) | |
y = 0 | |
:param slopes: b x 1 | |
:param intercepts: b x 1 | |
:return: points where line intersects unit square borders: b x pts(x, y): b x 2 x 2 | |
""" | |
batches = slopes.size(0) | |
x = torch.column_stack([torch.ones(batches), | |
torch.zeros(batches), | |
torch.divide(1 - intercept, slopes), | |
torch.divide(-1 * intercept, slopes)]) | |
y = torch.column_stack([slopes + intercept, | |
intercept, | |
torch.ones(batches), | |
torch.zeros(batches)]) | |
acceptance = (y >= 0) * (y <= 1) * (x >= 0) * (x <= 1) | |
return torch.column_stack((x[acceptance].reshape(batches, 1, -1), | |
y[acceptance].reshape(batches, 1, -1))) # b x pts(x, y) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment