Last active
July 19, 2025 18:54
-
-
Save vadimkantorov/ea989c75f79961fe46182845b40d5f31 to your computer and use it in GitHub Desktop.
Base64 decoding in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
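# Reference: https://en.wikipedia.org/wiki/Base64
# Bit layout: four 6-bit Base64 symbols (top row, each held in a byte with two leading zero bits)
# pack into three 8-bit bytes (bottom row):
# 00123456 00ABCDEF 00abcdef 00uvwxyz
# 123456AB CDEFabcd efuvwxyz
# This code does not support batches. Adapting it to e.g. a concatenated variable-length (varlen) format
# is possible, but the varlen information and the padding would need to be handled/preserved in some way.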
import torch
def base64_encode_padded(input_as_uint8_tensor: torch.Tensor) -> torch.Tensor:
    """Encode raw bytes as padded Base64 (standard alphabet, ``=`` padding).

    Args:
        input_as_uint8_tensor: 1-D ``torch.uint8`` tensor holding the bytes to
            encode. Batches are not supported.

    Returns:
        A new 1-D ``torch.uint8`` tensor, on the same device as the input,
        holding the ASCII codes of the Base64 text. Its length is
        ``4 * ceil(n / 3)`` for an input of ``n`` bytes.
    """
    base64_alphabet, base64_pad = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/', '='
    device = input_as_uint8_tensor.device

    # Lookup table: 6-bit symbol value -> ASCII code of the Base64 character.
    encode = torch.tensor(list(map(ord, base64_alphabet)), dtype=torch.uint8, device=device)
    # Zero bytes used to extend the input to a whole number of 3-byte groups.
    pad = torch.zeros(2, dtype=torch.uint8, device=device)
    # Column vectors, so that row i applies to byte i of every 3-byte group.
    masks = torch.tensor([[0b00000011], [0b00001111], [0b00111111]], dtype=torch.uint8, device=device)
    shifts = torch.tensor([[4], [2], [0], [2], [4], [6]], dtype=torch.uint8, device=device)

    mod = input_as_uint8_tensor.shape[0] % 3
    if mod == 0:
        padded = input_as_uint8_tensor
    else:
        padded = torch.cat([input_as_uint8_tensor, pad[:(3 - mod)]])

    # Shape (3, n_groups): row i holds byte i of every group.
    groups_t = padded.view(-1, 3).T.contiguous()

    # Low bits of each byte, moved to the top of the following 6-bit symbol:
    #   (b0 & 0b11) << 4, (b1 & 0b1111) << 2, (b2 & 0b111111) << 0
    low_parts = (groups_t & masks) << shifts[:3]
    # High bits of each byte, moved to the bottom of the current 6-bit symbol:
    #   b0 >> 2, b1 >> 4, b2 >> 6
    high_parts = groups_t >> shifts[3:]

    # Four 6-bit symbols per group, shape (n_groups, 4).
    symbols = torch.stack(
        [
            high_parts[0],
            low_parts[0] | high_parts[1],
            low_parts[1] | high_parts[2],
            low_parts[2],
        ],
        dim=-1,
    )

    # Index tensors are cast to int32 before the table lookup.
    res = encode[symbols.to(torch.int32).view(-1)]

    if mod > 0:
        # Symbols that came only from the zero padding are replaced by '='.
        res[-(3 - mod):] = ord(base64_pad)
    return res
def base64_decode_padded(base64_as_uint8_tensor: torch.Tensor) -> torch.Tensor:
    """Decode padded Base64 text (standard alphabet, ``=`` padding) into raw bytes.

    The input is not validated: any byte outside the Base64 alphabet
    (including the ``=`` pad character) is looked up as symbol value 0.

    Args:
        base64_as_uint8_tensor: 1-D ``torch.uint8`` tensor holding the ASCII
            codes of the Base64 text. Its length must be a multiple of 4,
            because the data is reshaped into 4-character groups. Batches are
            not supported.

    Returns:
        A 1-D ``torch.uint8`` tensor with the decoded bytes; one byte is
        dropped from the end for every ``=`` found among the last two input
        characters.
    """
    base64_alphabet, base64_pad = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/', '='
    device = base64_as_uint8_tensor.device

    # Lookup table: ASCII code -> 6-bit symbol value (0 for everything not in the alphabet).
    alphabet_codes = torch.tensor(list(map(ord, base64_alphabet)), device=device)
    symbol_values = torch.tensor(list(range(len(base64_alphabet))), dtype=torch.uint8, device=device)
    decode = torch.zeros(256, dtype=torch.uint8, device=device).put_(alphabet_codes, symbol_values)

    # Column vectors, so that row i applies to symbol i of every 4-symbol group.
    shifts = torch.tensor([[2], [4], [6], [4], [2], [0]], dtype=torch.uint8, device=device)

    # still can't index with dtypes lower than int32 :( https://github.com/pytorch/pytorch/issues/61819#issuecomment-3089865206
    symbols = decode[base64_as_uint8_tensor.to(torch.int32)].view(-1, 4)
    # Shape (4, n_groups): row i holds symbol i of every group.
    symbols_t = symbols.T.contiguous()

    # Output byte k of a group is (s[k] << a) | (s[k + 1] >> b):
    #   (s0 << 2) | (s1 >> 4), (s1 << 4) | (s2 >> 2), (s2 << 6) | (s3 >> 0)
    # The uint8 left shift discards the bits that belong to the previous byte.
    lsh = symbols_t[:-1] << shifts[:3]
    rsh = symbols_t[1:] >> shifts[3:]
    res = (lsh | rsh).T.contiguous().view(-1)

    # Each '=' among the last two characters stands for one byte of padding to drop.
    unpad = int((base64_as_uint8_tensor[-2:] == ord(base64_pad)).sum())
    res = res[:res.shape[0] - unpad]
    return res
if __name__ == '__main__':
    # Demo: decode three sample strings (one, two and zero '=' pad characters),
    # then re-encode the decoded bytes so the two printed pairs can be compared by eye.
    input_base64_str1 = 'bGlnaHQgd29yay4='
    input_base64_str2 = 'bGlnaHQgd29yaw=='
    input_base64_str3 = 'bGlnaHQgd29y'

    for input_base64_str in [input_base64_str1, input_base64_str2, input_base64_str3]:
        # Base64 text -> uint8 tensor of ASCII codes -> decoded bytes -> str.
        base64_as_uint8_tensor = torch.tensor(list(map(ord, input_base64_str)), dtype=torch.uint8)
        decoded_as_uint8_tensor = base64_decode_padded(base64_as_uint8_tensor)
        decoded_input_str = ''.join(map(chr, decoded_as_uint8_tensor.tolist()))
        print('"', input_base64_str, '" "', decoded_input_str, '"', sep='')

        # Round trip: the re-encoded text is printed next to the original input.
        encoded_as_uint8_tensor = base64_encode_padded(decoded_as_uint8_tensor)
        encoded_str = ''.join(map(chr, encoded_as_uint8_tensor.tolist()))
        print('"', input_base64_str, '" "', encoded_str, '"', sep='')
        print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment