replacement of scaled_l2 and aggregate in PyTorch-Encoding/encoding/functions/encoding.py with pure torch ops
import torch


def scaled_l2(X, C, S):
    """
    scaled_l2 distance
    Args:
        X (b*n*d): original feature input
        C (k*d): code words; k codes, each d-dimensional
        S (k): scale coefficients, one per code
    Return:
        D (b*n*k): scaled squared distance to each code
    Note:
        apparently the X^2 + C^2 - 2XC expansion is ~2x faster than the
        elementwise broadcast-and-sum version, perhaps due to cache-friendly
        access patterns on the GPU
    """
    assert X.shape[-1] == C.shape[-1], "input, codeword feature dim mismatch"
    assert S.numel() == C.shape[0], "scale, codeword num mismatch"
""" | |
# simplier but slower | |
X = X.unsqueeze(2) | |
C = C[None, None,...] | |
norm = torch.norm(X-C, dim=-1).pow(2.0) | |
scaled_norm = S * norm | |
""" | |
    b, n, d = X.shape
    X = X.view(-1, d)                      # [bn, d]
    Ct = C.t()                             # [d, k]
    X2 = X.pow(2.0).sum(-1, keepdim=True)  # [bn, 1]
    C2 = Ct.pow(2.0).sum(0, keepdim=True)  # [1, k]
    norm = X2 + C2 - 2.0 * X.mm(Ct)        # [bn, k]
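    # caveat: this expansion can go slightly negative in floating point for
    # features very close to a code word; clamp with norm.clamp_(min=0.0)
    # if exact non-negativity is required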
    scaled_norm = S * norm
    D = scaled_norm.view(b, n, -1)         # [b, n, k]
    return D

def aggregate(A, X, C):
    """
    aggregate residuals over the n features of each sample
    Args:
        A (b*n*k): weight of each feature's contribution to each code's residual
        X (b*n*d): original feature input
        C (k*d): code words; k codes, each d-dimensional
    Return:
        E (b*k*d): aggregated residuals for each code
    """
    assert X.shape[-1] == C.shape[-1], "input, codeword feature dim mismatch"
    assert A.shape[:2] == X.shape[:2], "weight, input dim mismatch"
    X = X.unsqueeze(2)      # [b, n, d] -> [b, n, 1, d]
    C = C[None, None, ...]  # [k, d] -> [1, 1, k, d]
    A = A.unsqueeze(-1)     # [b, n, k] -> [b, n, k, 1]
    R = (X - C) * A         # [b, n, k, d]
    E = R.sum(dim=1)        # [b, k, d]
    return E
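
For a quick sanity check, the two functions can be chained the way the Encoding layer in PyTorch-Encoding uses them: the scaled distances feed a softmax over the k codes, and the resulting soft-assignment weights drive the aggregation. The snippet below is an illustrative sketch with arbitrary shapes, not part of the original gist; it also cross-checks scaled_l2 against the simpler broadcast form quoted in its docstring.

b, n, d, k = 2, 16, 32, 8    # batch size, features per sample, feature dim, num codes
X = torch.randn(b, n, d)     # input features
C = torch.randn(k, d)        # code words
S = torch.randn(k)           # per-code scale coefficients

D = scaled_l2(X, C, S)       # [b, n, k] scaled squared distances
A = torch.softmax(D, dim=2)  # [b, n, k] soft assignment of features to codes
E = aggregate(A, X, C)       # [b, k, d] aggregated residuals
assert E.shape == (b, k, d)

# cross-check against the broadcast version from the scaled_l2 docstring
ref = S * (X.unsqueeze(2) - C[None, None, ...]).pow(2.0).sum(-1)
assert torch.allclose(D, ref, atol=1e-4)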