Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active July 19, 2025 18:54
Show Gist options
  • Save vadimkantorov/ea989c75f79961fe46182845b40d5f31 to your computer and use it in GitHub Desktop.
Save vadimkantorov/ea989c75f79961fe46182845b40d5f31 to your computer and use it in GitHub Desktop.
Base64 decoding in PyTorch
# https://en.wikipedia.org/wiki/Base64
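# Bit layout (each character below stands for one bit):
#   four 6-bit base64 symbols, one per byte: 00123456 00ABCDEF 00abcdef 00uvwxyz
#   pack into three 8-bit data bytes:        123456AB CDEFabcd efuvwxyz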
# NOTE: this code does not support batches. Adapting it to e.g. a concatenated variable-length (varlen) format is possible, but the per-sequence lengths and paddings would then need to be tracked and preserved in some way.
import torch
def base64_encode_padded(input_as_uint8_tensor):
    """Encode a 1-D uint8 tensor of raw bytes into padded base64.

    The input is zero-padded to a multiple of 3 bytes, every 3-byte group is
    split into four 6-bit symbols, and the symbols are mapped through the
    standard base64 alphabet. Trailing symbols that come purely from the
    zero padding are overwritten with '='.

    Args:
        input_as_uint8_tensor: 1-D tensor of bytes (only ``shape[0]`` is
            consulted for the length). Batches are not supported.

    Returns:
        1-D uint8 tensor holding the ASCII codes of the base64 text, on the
        same device as the input. An empty input yields an empty tensor.
    """
    base64_alphabet, base64_pad = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/', '='
    device = input_as_uint8_tensor.device

    # Empty input encodes to empty output. It must be handled up front:
    # view(-1, 3) on a zero-element tensor raises because -1 is ambiguous.
    if input_as_uint8_tensor.shape[0] == 0:
        return torch.empty(0, dtype = torch.uint8, device = device)

    encode = torch.tensor(list(map(ord, base64_alphabet)), dtype = torch.uint8, device = device)
    pad = torch.zeros(2, dtype = torch.uint8, device = device)
    masks = torch.tensor([[0b00000011], [0b00001111], [0b00111111]], dtype = torch.uint8, device = device)
    shifts = torch.tensor([[4], [2], [0], [2], [4], [6]], dtype = torch.uint8, device = device)

    # Zero-pad up to a whole number of 3-byte groups.
    mod = input_as_uint8_tensor.shape[0] % 3
    padded = input_as_uint8_tensor if mod == 0 else torch.cat([input_as_uint8_tensor, pad[:(3 - mod)]])
    # Rows of `groups` are: first, second and third byte of every group.
    groups = padded.view(-1, 3).T.contiguous()

    # low_bits[i]  = low bits of byte i, moved to the top of a 6-bit symbol
    # high_bits[i] = high bits of byte i, moved to the bottom of a 6-bit symbol
    low_bits = (groups & masks) << shifts[:3]
    high_bits = groups >> shifts[3:]

    # Equivalent to:
    #   b0 >> 2, ((b0 & 0b11) << 4) | (b1 >> 4), ((b1 & 0b1111) << 2) | (b2 >> 6), b2 & 0b111111
    symbols = torch.stack([high_bits[0], low_bits[0] | high_bits[1], low_bits[1] | high_bits[2], low_bits[2]], dim = -1)

    # Indexing needs an int32/int64 index tensor, hence the cast.
    res = encode[symbols.to(torch.int32).view(-1)]
    if mod > 0:
        # 1 leftover byte -> 2 pad chars, 2 leftover bytes -> 1 pad char.
        res[-(3 - mod):] = ord(base64_pad)
    return res
def base64_decode_padded(base64_as_uint8_tensor):
    """Decode padded base64 text, given as a 1-D uint8 tensor of ASCII codes.

    Every group of four base64 symbols is mapped back to 6-bit values and
    repacked into three bytes. Bytes produced by trailing '=' padding are
    then trimmed.

    Args:
        base64_as_uint8_tensor: 1-D uint8 tensor of base64 characters. Its
            length must be a multiple of 4 (i.e. the text must be padded).
            Batches are not supported.

    Returns:
        1-D uint8 tensor with the decoded bytes, on the same device as the
        input. An empty input yields an empty tensor.

    Note:
        No validation is performed: any character outside the base64
        alphabet (including '=') is looked up as 0 in the decode table, and
        only the last two characters are inspected for padding.
    """
    base64_alphabet, base64_pad = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/', '='
    device = base64_as_uint8_tensor.device

    # Empty text decodes to empty output. It must be handled up front:
    # view(-1, 4) on a zero-element tensor raises because -1 is ambiguous.
    if base64_as_uint8_tensor.shape[0] == 0:
        return torch.empty(0, dtype = torch.uint8, device = device)

    # Lookup table: ASCII code -> 6-bit value (0 for anything not in the alphabet).
    alphabet_codes = torch.tensor(list(map(ord, base64_alphabet)), device = device)
    alphabet_values = torch.tensor(list(range(len(base64_alphabet))), dtype = torch.uint8, device = device)
    decode = torch.zeros(256, dtype = torch.uint8, device = device).put_(alphabet_codes, alphabet_values)

    shifts = torch.tensor([[2], [4], [6], [4], [2], [0]], dtype = torch.uint8, device = device)

    # still can't index with dtypes lower than int32 :( https://github.com/pytorch/pytorch/issues/61819#issuecomment-3089865206
    sextets = decode[base64_as_uint8_tensor.to(torch.int32)].view(-1, 4)
    # Rows of `sextets_t` are: first, second, third and fourth symbol of every group.
    sextets_t = sextets.T.contiguous()

    # Output byte i combines the low bits of symbol i (shifted up, with the
    # overflow dropped by uint8 arithmetic) and the high bits of symbol i + 1.
    # Equivalent to:
    #   (s0 << 2) | (s1 >> 4), (s1 << 4) | (s2 >> 2), (s2 << 6) | (s3 >> 0)
    lsh = sextets_t[:-1] << shifts[:3]
    rsh = sextets_t[ 1:] >> shifts[3:]
    res = (lsh | rsh).T.contiguous().view(-1)

    # Each trailing '=' stands for one byte that was zero-padding, not data.
    unpad = int((base64_as_uint8_tensor[-2:] == ord(base64_pad)).sum())
    return res[:res.shape[0] - unpad]
if __name__ == '__main__':
    # Round-trip demo: decode each sample, print the decoded text, then
    # re-encode the decoded bytes and print the result next to the original.
    # The samples cover one, two and zero '=' padding characters.
    samples = ['bGlnaHQgd29yay4=', 'bGlnaHQgd29yaw==', 'bGlnaHQgd29y']
    for input_base64_str in samples:
        char_codes = [ord(ch) for ch in input_base64_str]
        base64_as_uint8_tensor = torch.tensor(char_codes, dtype = torch.uint8)

        decoded_as_uint8_tensor = base64_decode_padded(base64_as_uint8_tensor)
        decoded_input_str = ''.join(chr(code) for code in decoded_as_uint8_tensor.tolist())
        print('"', input_base64_str, '" "', decoded_input_str, '"', sep='')

        encoded_as_uint8_tensor = base64_encode_padded(decoded_as_uint8_tensor)
        encoded_str = ''.join(chr(code) for code in encoded_as_uint8_tensor.tolist())
        print('"', input_base64_str, '" "', encoded_str, '"', sep='')
        print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment