Created November 1, 2024 13:27
Save jfeaver/e67eaf5bb3a1f72ff7eb7d6563f380dc to your computer and use it in GitHub Desktop.
An all-purpose Base64 encoder for URLs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""An all purpose base64 encoder for URLs. | |
It even does emojis! You commonly will use string_to_base64_url and string_from_base64_url. | |
Adapted from https://github.com/supabase-community/base64url-js. | |
""" | |
from collections.abc import Callable

BITS_PER_BYTE = 8
SINGLE_BYTE_MASK = 0b11111111
MAX_ASCII_VALUE = 0b1111111  # 7 bits
TWO_BYTE_CHARACTER_MAX = 0b11111111111  # 11 bits
THREE_BYTE_CHARACTER_MAX = 0b1111111111111111  # 16 bits
UNICODE_MAX_CODEPOINT = 0x10FFFF

# UTF-8 continuation bytes have the form 10xxxxxx; this is that leading bit pattern. It marks a
# byte that is not standalone but part of a multi-byte character.
CONTINUATION_BYTE_MASK = 0b10000000  # 8th bit

# Leading-byte markers for 2, 3, and 4 byte UTF-8 characters, respectively.
TWO_BYTE_UTF8_PREFIX = 0b11000000  # 7th and 8th bits
THREE_BYTE_UTF8_PREFIX = 0b11100000  # 6-8th bits
FOUR_BYTE_UTF8_PREFIX = 0b11110000  # 5-8th bits

# For extracting the six payload bits of each UTF-8 continuation byte. Six bits is also the
# width of one Base64 character, which is why the Base64 helpers reuse these two constants.
UTF_8_PAYLOAD_MASK = 0b111111  # six bits
UTF_8_PAYLOAD_BIT_COUNT = 6

# Masks that strip the length-prefix bits from the leading byte of 2-, 3-, and 4-byte UTF-8
# sequences, respectively.
TWO_BYTE_UTF8_MASK = 0b11111  # five bits
THREE_BYTE_UTF8_MASK = 0b1111
FOUR_BYTE_UTF8_MASK = 0b111

# UTF-16 surrogate ranges (surrogates are a UTF-16 concept, not a UTF-8 one). A Python str can
# still contain surrogate code points; a high surrogate followed by a low surrogate encodes a
# single character beyond the Basic Multilingual Plane (BMP).
HIGH_SURROGATE_START = 0xD800
HIGH_SURROGATE_END = 0xDBFF
LOW_SURROGATE_START = 0xDC00

# The Base64-URL alphabet (RFC 4648 section 5): index = 6-bit value, item = output character.
TO_BASE64URL = list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")

# Characters that may appear in Base64-URL input but are skipped (whitespace and '=' padding).
IGNORE_BASE64URL = list(" \t\n\r=")


def _build_from_base64url() -> list[int]:
    """Build the ASCII reverse lookup table for the Base64-URL alphabet.

    Built inside a function so the loop variables stay local instead of becoming module
    globals (which `from ... import *` would otherwise export).
    """
    table = [-1] * 128
    for ignored in IGNORE_BASE64URL:
        table[ord(ignored)] = -2
    for value, symbol in enumerate(TO_BASE64URL):
        table[ord(symbol)] = value
    return table


# 128 entries mapping an ASCII code to its 6-bit Base64-URL value;
# -2 means "skip this character" and -1 means "invalid character".
FROM_BASE64URL = _build_from_base64url()
def _drain_base64_url_queue(state: dict[str, int], emit: Callable) -> None:
    """Emit one Base64-URL character for every complete 6-bit group queued in `state`."""
    while state["queued_bits"] >= UTF_8_PAYLOAD_BIT_COUNT:
        state["queued_bits"] -= UTF_8_PAYLOAD_BIT_COUNT
        pos = (state["queue"] >> state["queued_bits"]) & UTF_8_PAYLOAD_MASK
        emit(TO_BASE64URL[pos])
    # Discard the bits that have already been emitted. Python ints are unbounded (JavaScript
    # bitwise operators, used by the original port, truncate to 32 bits), so without this the
    # queue would grow by 8 bits per input byte and encoding would be quadratic in input length.
    state["queue"] &= (1 << state["queued_bits"]) - 1


def byte_to_base64_url(byte: int | None, state: dict[str, int], emit: Callable) -> None:
    """Feed one byte into a streaming Base64-URL encoder.

    Modifies `state` in place. Call once per byte, then once more with `None` to flush.
    No '=' padding is ever emitted.

    Args:
    ----
        byte: The byte value to encode (expected 0-255), or None to flush at the end of the
            byte sequence.
        state: The Base64 conversion state. Pass an initial value of
            `{"queue": 0, "queued_bits": 0}`.
        emit: A function called with each Base64-URL character as soon as it is ready.
    """
    if byte is not None:
        state["queue"] = (state["queue"] << BITS_PER_BYTE) | byte
        state["queued_bits"] += BITS_PER_BYTE
    elif state["queued_bits"] > 0:
        # Flush: right-pad the leftover bits with zeros to make one final 6-bit group.
        state["queue"] <<= UTF_8_PAYLOAD_BIT_COUNT - state["queued_bits"]
        state["queued_bits"] = UTF_8_PAYLOAD_BIT_COUNT
    else:
        # Flush with nothing pending: no output.
        return
    _drain_base64_url_queue(state, emit)
def byte_from_base64_url(char_code: int, state: dict[str, int], emit: Callable) -> None:
    """Feed one Base64-URL character into a streaming decoder, emitting completed bytes.

    The char code is obtained with `ord(char)`. Whitespace and '=' padding are skipped.
    Modifies `state` in place.

    Args:
    ----
        char_code: The code point of one character of the Base64-URL input.
        state: The Base64 decoding state. Pass an initial value of
            `{"queue": 0, "queued_bits": 0}`.
        emit: A function called with each decoded byte (an int in 0-255).

    Raises:
    ------
        ValueError: If the character is neither in the Base64-URL alphabet nor skippable.
    """
    skip_marker = -2  # FROM_BASE64URL value for characters that are silently ignored
    # The lookup table only covers ASCII; any other code point is an invalid character, so
    # map it to -1 rather than letting an IndexError escape.
    if 0 <= char_code < len(FROM_BASE64URL):
        bits = FROM_BASE64URL[char_code]
    else:
        bits = -1

    if bits >= 0:
        state["queue"] = (state["queue"] << UTF_8_PAYLOAD_BIT_COUNT) | bits
        state["queued_bits"] += UTF_8_PAYLOAD_BIT_COUNT
        while state["queued_bits"] >= BITS_PER_BYTE:
            state["queued_bits"] -= BITS_PER_BYTE
            emit((state["queue"] >> state["queued_bits"]) & SINGLE_BYTE_MASK)
        # Keep only the not-yet-emitted bits; Python ints are unbounded, so otherwise the
        # queue would grow with every character and decoding would be quadratic.
        state["queue"] &= (1 << state["queued_bits"]) - 1
    elif bits == skip_marker:
        # Whitespace or '=' padding: ignore this character and keep decoding.
        return
    else:
        message = f"Invalid Base64-URL character '{chr(char_code)}'"
        raise ValueError(message)
def string_to_base64_url(s: str) -> str:
    """Encode a string (which may include any valid character) as Base64-URL.

    The string is first encoded as UTF-8, and those bytes are then encoded as Base64-URL
    without '=' padding.

    Args:
    ----
        s: The string to convert.

    Returns:
    -------
        The unpadded Base64-URL encoding of `s`.
    """
    encoded: list[str] = []
    state = {"queue": 0, "queued_bits": 0}

    def emit_char(char: str) -> None:
        # Receives one Base64-URL alphabet character (a str, not an int).
        encoded.append(char)

    string_to_utf8(s, lambda byte: byte_to_base64_url(byte, state, emit_char))
    # Flush any leftover bits as a final, zero-padded character.
    byte_to_base64_url(None, state, emit_char)
    return "".join(encoded)
def string_from_base64_url(s: str) -> str:
    """Decode a Base64-URL encoded string back into the original string.

    The decoded bytes are interpreted as UTF-8. Whitespace and '=' padding in `s` are
    ignored.

    Args:
    ----
        s: The Base64-URL encoded string.

    Returns:
    -------
        The decoded string.

    Raises:
    ------
        ValueError: If the decoded bytes are not valid UTF-8, including a multi-byte
            character cut off at the end of the input. Invalid Base64-URL characters are
            reported by `byte_from_base64_url`.
    """
    conv: list[str] = []
    utf8_state = {"utf8seq": 0, "codepoint": 0}
    b64_state = {"queue": 0, "queued_bits": 0}

    def utf8_emit(codepoint: int) -> None:
        conv.append(chr(codepoint))

    def byte_emit(byte: int) -> None:
        string_from_utf8(byte, utf8_state, utf8_emit)

    for char in s:
        byte_from_base64_url(ord(char), b64_state, byte_emit)

    if utf8_state["utf8seq"] > 0:
        # The input ended in the middle of a multi-byte UTF-8 character. Report it like the
        # other invalid-UTF-8 cases instead of silently dropping the partial character.
        message = "Invalid UTF-8 sequence: input ends mid-character"
        raise ValueError(message)
    return "".join(conv)
def codepoint_to_utf8(codepoint: int, emit: Callable) -> None:
    """Convert a Unicode codepoint to its 1-4 byte UTF-8 sequence.

    Surrogate codepoints (U+D800-U+DFFF) are not rejected; they are encoded with the
    3-byte form.

    Args:
    ----
        codepoint: The Unicode codepoint, in the range 0 to 0x10FFFF.
        emit: A function that is called, in order, for each UTF-8 byte that represents
            the codepoint.

    Raises:
    ------
        ValueError: If `codepoint` is negative or greater than 0x10FFFF.
    """
    if not 0 <= codepoint <= UNICODE_MAX_CODEPOINT:
        # Negative values must be rejected here too; otherwise they would fall into the
        # ASCII branch below and be emitted as a negative "byte".
        message = f"Unrecognized Unicode codepoint: {hex(codepoint)}"
        raise ValueError(message)

    # Each continuation byte carries 6 payload bits, hence the shifts by 6, 12 and 18.
    if codepoint <= MAX_ASCII_VALUE:
        emit(codepoint)
    elif codepoint <= TWO_BYTE_CHARACTER_MAX:
        emit(TWO_BYTE_UTF8_PREFIX | (codepoint >> 6))
        emit(CONTINUATION_BYTE_MASK | (codepoint & UTF_8_PAYLOAD_MASK))
    elif codepoint <= THREE_BYTE_CHARACTER_MAX:
        emit(THREE_BYTE_UTF8_PREFIX | (codepoint >> 12))
        emit(CONTINUATION_BYTE_MASK | ((codepoint >> 6) & UTF_8_PAYLOAD_MASK))
        emit(CONTINUATION_BYTE_MASK | (codepoint & UTF_8_PAYLOAD_MASK))
    else:
        emit(FOUR_BYTE_UTF8_PREFIX | (codepoint >> 18))
        emit(CONTINUATION_BYTE_MASK | ((codepoint >> 12) & UTF_8_PAYLOAD_MASK))
        emit(CONTINUATION_BYTE_MASK | ((codepoint >> 6) & UTF_8_PAYLOAD_MASK))
        emit(CONTINUATION_BYTE_MASK | (codepoint & UTF_8_PAYLOAD_MASK))
def string_to_utf8(s: str, emit: Callable) -> None:
    """Convert a string to a sequence of UTF-8 bytes.

    A high surrogate immediately followed by a low surrogate is combined into the single
    codepoint the pair represents. Any other surrogate (a lone high surrogate, including one
    at the end of the string, or a lone low surrogate) is encoded on its own, so no
    neighbouring character is ever consumed or altered.

    Args:
    ----
        s: The string to convert to UTF-8.
        emit: A function that is called for each UTF-8 byte of the string.
    """
    low_surrogate_end = 0xDFFF  # last code point of the low-surrogate range
    length = len(s)
    i = 0
    while i < length:
        codepoint = ord(s[i])
        if (
            HIGH_SURROGATE_START <= codepoint <= HIGH_SURROGATE_END
            and i + 1 < length
            and LOW_SURROGATE_START <= ord(s[i + 1]) <= low_surrogate_end
        ):
            high_surrogate = (codepoint - HIGH_SURROGATE_START) * 0x400
            low_surrogate = ord(s[i + 1]) - LOW_SURROGATE_START
            codepoint = (low_surrogate | high_surrogate) + 0x10000
            i += 1  # the low surrogate has been consumed as part of the pair
        codepoint_to_utf8(codepoint, emit)
        i += 1
def string_from_utf8(byte: int, state: dict[str, int], emit: Callable) -> None:
    """Feed one byte into a streaming UTF-8 decoder.

    ASCII bytes are emitted immediately. A leading byte starts a multi-byte sequence whose
    codepoint is emitted once its final continuation byte arrives. Modifies `state` in place.

    Args:
    ----
        byte: The next UTF-8 byte in the sequence.
        state: Decoder state shared between consecutive calls; pass an initial value of
            `{"utf8seq": 0, "codepoint": 0}`. `utf8seq` counts the continuation bytes still
            expected, and `codepoint` holds the bits accumulated so far.
        emit: A function that is called with each decoded codepoint.

    Raises:
    ------
        ValueError: Propagated from the helper functions when the bytes are not valid UTF-8.
    """
    pending = state["utf8seq"]

    if pending > 0:
        # Mid-sequence: this byte has to be a continuation byte.
        _continue_utf8_sequence(byte, state, emit)
        return
    if pending < 0:
        # A negative counter is left untouched and the byte is ignored.
        return

    if byte <= MAX_ASCII_VALUE:
        # Single-byte character: already a complete codepoint.
        emit(byte)
        return

    sequence_length = _detect_utf8_sequence_length(byte)
    state["utf8seq"] = sequence_length
    state["codepoint"] = _initialize_codepoint(byte, sequence_length)
    # The leading byte has been consumed; only the continuation bytes remain.
    state["utf8seq"] = sequence_length - 1
def _detect_utf8_sequence_length(byte: int) -> int: | |
"""Determine the length of the UTF-8 sequence based on leading bits in the first byte.""" | |
for leading_bit in range(1, 6): # UTF-8 supports up to 4 bytes | |
if ((byte >> (7 - leading_bit)) & 1) == 0: | |
return leading_bit | |
message = "Invalid UTF-8 sequence" | |
raise ValueError(message) | |
def _initialize_codepoint(byte: int, sequence_length: int) -> int: | |
"""Initialize the codepoint from the leading byte of a UTF-8 sequence.""" | |
if sequence_length == 2: # noqa: PLR2004 | |
return byte & TWO_BYTE_UTF8_MASK | |
if sequence_length == 3: # noqa: PLR2004 | |
return byte & THREE_BYTE_UTF8_MASK | |
if sequence_length == 4: # noqa: PLR2004 | |
return byte & FOUR_BYTE_UTF8_MASK | |
message = "Invalid UTF-8 sequence" | |
raise ValueError(message) | |
def _continue_utf8_sequence(byte: int, state: dict[str, int], emit: Callable) -> None:
    """Consume one continuation byte of a multi-byte UTF-8 sequence.

    Appends the byte's six payload bits to `state["codepoint"]`, decrements
    `state["utf8seq"]`, and emits the codepoint once no continuation bytes remain.
    Modifies `state` in place.

    Raises:
    ------
        ValueError: If `byte` is not a continuation byte (bit pattern 10xxxxxx).
    """
    # A continuation byte must have exactly "10" in its top two bits. Checking only
    # `byte < 0x80` would let a new leading byte (11xxxxxx) be absorbed as payload.
    if (byte & ~UTF_8_PAYLOAD_MASK) != CONTINUATION_BYTE_MASK:
        message = "Invalid UTF-8 continuation byte"
        raise ValueError(message)
    state["codepoint"] = (state["codepoint"] << UTF_8_PAYLOAD_BIT_COUNT) | (
        byte & UTF_8_PAYLOAD_MASK
    )
    state["utf8seq"] -= 1
    if state["utf8seq"] == 0:
        emit(state["codepoint"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
After finishing this adaptation, it was pointed out to me that I should use the standard library. Oops! The `base64` module produces the same encodings, except that it adds `=` padding, which this encoder omits. I'll abandon this and just use `base64.urlsafe_b64encode` (applied to `s.encode("utf-8")`).