Skip to content

Instantly share code, notes, and snippets.

@jfeaver
Created November 1, 2024 13:27
Show Gist options
  • Save jfeaver/e67eaf5bb3a1f72ff7eb7d6563f380dc to your computer and use it in GitHub Desktop.
Save jfeaver/e67eaf5bb3a1f72ff7eb7d6563f380dc to your computer and use it in GitHub Desktop.
An all purpose base64 encoder for URLs
"""An all-purpose Base64 encoder for URLs.

It even handles emojis! The functions you will most commonly use are string_to_base64_url and
string_from_base64_url.
Adapted from https://github.com/supabase-community/base64url-js.
"""
from collections.abc import Callable
# --- Byte-level sizes -------------------------------------------------------------------------
BITS_PER_BYTE = 8
SINGLE_BYTE_MASK = 0xFF

# --- Largest code point that fits in a UTF-8 sequence of a given length -----------------------
MAX_ASCII_VALUE = 0x7F  # one byte: 7 payload bits
TWO_BYTE_CHARACTER_MAX = 0x7FF  # two bytes: 11 payload bits
THREE_BYTE_CHARACTER_MAX = 0xFFFF  # three bytes: 16 payload bits
UNICODE_MAX_CODEPOINT = 0x10FFFF

# --- UTF-8 byte markers -----------------------------------------------------------------------
# Marker OR-ed onto every continuation byte (0b10xxxxxx): the byte is not standalone but
# belongs to a multi-byte character.
CONTINUATION_BYTE_MASK = 0x80  # 0b10000000
# Markers OR-ed onto the first byte of a 2-, 3-, and 4-byte UTF-8 character, respectively.
TWO_BYTE_UTF8_PREFIX = 0xC0  # 0b11000000
THREE_BYTE_UTF8_PREFIX = 0xE0  # 0b11100000
FOUR_BYTE_UTF8_PREFIX = 0xF0  # 0b11110000

# --- Payload extraction -----------------------------------------------------------------------
# Six bits at a time: the payload of a UTF-8 continuation byte. The Base64 routines in this
# module reuse the same two values for their 6-bit groups.
UTF_8_PAYLOAD_BIT_COUNT = 6
UTF_8_PAYLOAD_MASK = (1 << UTF_8_PAYLOAD_BIT_COUNT) - 1  # 0b111111
# Masks that strip the marker bits from the first byte of a 2-, 3-, and 4-byte sequence,
# leaving only the payload bits.
TWO_BYTE_UTF8_MASK = 0x1F  # five bits
THREE_BYTE_UTF8_MASK = 0x0F  # four bits
FOUR_BYTE_UTF8_MASK = 0x07  # three bits

# --- UTF-16 surrogates ------------------------------------------------------------------------
# Surrogate pairs are a UTF-16 mechanism: a high surrogate followed by a low surrogate together
# represent one code point beyond the Basic Multilingual Plane (BMP).
HIGH_SURROGATE_START = 0xD800
HIGH_SURROGATE_END = 0xDBFF
LOW_SURROGATE_START = 0xDC00
# The Base64-URL alphabet: the character at index n encodes the 6-bit value n.
TO_BASE64URL = list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
# Characters that may appear in a Base64-URL encoded string but carry no data (whitespace and
# '=' padding); the decoder skips them.
IGNORE_BASE64URL = list(" \t\n\r=")


def _build_from_base64url() -> list[int]:
    """Build the ASCII decode table for Base64-URL.

    Returns a list of 128 integers indexed by character code: the 6-bit value for an alphabet
    character, -2 for a character that should be skipped, and -1 for anything else (an error).
    Building the table inside a function keeps the loop variables out of the module namespace.
    """
    table = [-1] * 128
    for ignored in IGNORE_BASE64URL:
        table[ord(ignored)] = -2
    for value, symbol in enumerate(TO_BASE64URL):
        table[ord(symbol)] = value
    return table


# An array of 128 numbers that map a Base64-URL character code to 6 bits;
# -2 means "skip the character" and -1 indicates an error.
FROM_BASE64URL = _build_from_base64url()
def byte_to_base64_url(byte: int | None, state: dict[str, int], emit: Callable[[str], None]) -> None:
    """Feed one byte into the streaming Base64-URL encoder.

    Bits are accumulated in `state` and a Base64-URL character is emitted for every complete
    6-bit group. No '=' padding is ever emitted. Modifies `state` in place.

    Args:
    ----
        byte: The byte to convert, or None to flush at the end of the byte sequence.
        state: The Base64 conversion state. Pass an initial value of
            `{"queue": 0, "queued_bits": 0}`.
        emit: A function called with the next Base64-URL character when ready.

    """
    if byte is not None:
        state["queue"] = (state["queue"] << BITS_PER_BYTE) | byte
        state["queued_bits"] += BITS_PER_BYTE
    elif state["queued_bits"] > 0:
        # Flush: pad the leftover bits with zeros on the right to form one last 6-bit group.
        state["queue"] <<= UTF_8_PAYLOAD_BIT_COUNT - state["queued_bits"]
        state["queued_bits"] = UTF_8_PAYLOAD_BIT_COUNT
    else:
        # Flush requested but nothing is pending.
        return

    while state["queued_bits"] >= UTF_8_PAYLOAD_BIT_COUNT:
        pos = (
            state["queue"] >> (state["queued_bits"] - UTF_8_PAYLOAD_BIT_COUNT)
        ) & UTF_8_PAYLOAD_MASK
        emit(TO_BASE64URL[pos])
        state["queued_bits"] -= UTF_8_PAYLOAD_BIT_COUNT

    # Drop the bits that were already emitted. Python integers never overflow, so without this
    # the queue would keep every byte ever fed in and each shift would get slower and slower.
    state["queue"] &= (1 << state["queued_bits"]) - 1
def byte_from_base64_url(char_code: int, state: dict[str, int], emit: Callable[[int], None]) -> None:
    """Feed one Base64-URL character into the streaming decoder.

    The character's 6 bits are accumulated in `state` and a byte is emitted for every complete
    group of 8 bits. Whitespace and '=' padding are skipped. The char code is obtained with
    `ord(character)`. Modifies `state` in place.

    Args:
    ----
        char_code: The char code of the next character of the encoded string.
        state: The Base64 state. Pass an initial value of `{"queue": 0, "queued_bits": 0}`.
        emit: A function called with the next decoded byte.

    Raises:
    ------
        ValueError: If the character is neither part of the Base64-URL alphabet nor ignorable.

    """
    # The lookup table only covers ASCII. Anything outside it is invalid; without this guard a
    # non-ASCII character raised IndexError and a negative code indexed from the end of the table.
    bits = FROM_BASE64URL[char_code] if 0 <= char_code < len(FROM_BASE64URL) else -1

    if bits == -2:  # noqa: PLR2004
        # Marked as ignorable in the table (whitespace or padding): carries no data.
        return
    if bits < 0:
        message = f"Invalid Base64-URL character '{chr(char_code)}'"
        raise ValueError(message)

    state["queue"] = (state["queue"] << UTF_8_PAYLOAD_BIT_COUNT) | bits
    state["queued_bits"] += UTF_8_PAYLOAD_BIT_COUNT
    while state["queued_bits"] >= BITS_PER_BYTE:
        next_up_bits = state["queue"] >> (state["queued_bits"] - BITS_PER_BYTE)
        emit(next_up_bits & SINGLE_BYTE_MASK)
        state["queued_bits"] -= BITS_PER_BYTE

    # Drop the bits that were already emitted so the queue stays small; Python integers do not
    # truncate on their own.
    state["queue"] &= (1 << state["queued_bits"]) - 1
def string_to_base64_url(s: str) -> str:
    """Return `s` encoded as Base64-URL.

    The string is streamed through `string_to_utf8`, and every resulting UTF-8 byte is handed to
    `byte_to_base64_url`; the emitted characters are joined into the result.

    Args:
    ----
        s: The string to convert (may include any valid character).

    """
    encoder_state = {"queue": 0, "queued_bits": 0}
    pieces: list[str] = []
    collect = pieces.append

    def feed(byte: int) -> None:
        byte_to_base64_url(byte, encoder_state, collect)

    string_to_utf8(s, feed)
    # A final call with None flushes whatever bits are still queued.
    byte_to_base64_url(None, encoder_state, collect)
    return "".join(pieces)
def string_from_base64_url(s: str) -> str:
    """Return the string that was encoded as the Base64-URL text `s`.

    Each character is passed to `byte_from_base64_url`; the bytes it emits are passed to
    `string_from_utf8`, and the code points that emits are collected into the result. The
    underlying bytes are therefore assumed to be UTF-8.

    Args:
    ----
        s: The Base64-URL encoded string.

    """
    decoded_chars: list[str] = []
    utf8_progress = {"utf8seq": 0, "codepoint": 0}
    base64_progress = {"queue": 0, "queued_bits": 0}

    def on_codepoint(codepoint: int) -> None:
        decoded_chars.append(chr(codepoint))

    def on_byte(byte: int) -> None:
        string_from_utf8(byte, utf8_progress, on_codepoint)

    for code in map(ord, s):
        byte_from_base64_url(code, base64_progress, on_byte)
    return "".join(decoded_chars)
def codepoint_to_utf8(codepoint: int, emit: Callable) -> None:
    """Emit the UTF-8 byte sequence (one to four bytes) for a Unicode codepoint.

    Args:
    ----
        codepoint: The Unicode codepoint.
        emit: A function that is called for each UTF-8 byte that represents the codepoint.

    Raises:
    ------
        ValueError: If the codepoint is greater than `UNICODE_MAX_CODEPOINT`.

    """
    if codepoint <= MAX_ASCII_VALUE:
        # Fits in a single byte; no markers needed.
        emit(codepoint)
        return

    # Pick the leading-byte marker and the number of continuation bytes that follow it.
    if codepoint <= TWO_BYTE_CHARACTER_MAX:
        lead_marker, trailing = TWO_BYTE_UTF8_PREFIX, 1
    elif codepoint <= THREE_BYTE_CHARACTER_MAX:
        lead_marker, trailing = THREE_BYTE_UTF8_PREFIX, 2
    elif codepoint <= UNICODE_MAX_CODEPOINT:
        lead_marker, trailing = FOUR_BYTE_UTF8_PREFIX, 3
    else:
        message = f"Unrecognized Unicode codepoint: {hex(codepoint)}"
        raise ValueError(message)

    # The leading byte carries the highest bits; each continuation byte carries the next six.
    emit(lead_marker | (codepoint >> (6 * trailing)))
    for remaining in range(trailing - 1, -1, -1):
        emit(CONTINUATION_BYTE_MASK | ((codepoint >> (6 * remaining)) & UTF_8_PAYLOAD_MASK))
def string_to_utf8(s: str, emit: Callable[[int], None]) -> None:
    """Convert a string to a sequence of UTF-8 bytes.

    A high surrogate immediately followed by a low surrogate is combined into the single code
    point the pair represents. A surrogate that is not part of such a pair (a high surrogate at
    the end of the string or followed by anything else, or a low surrogate on its own) is passed
    to `codepoint_to_utf8` unchanged, and the character after it is processed normally.

    Args:
    ----
        s: The string to convert to UTF-8.
        emit: A function that is called for each UTF-8 byte of the string.

    """
    low_surrogate_end = LOW_SURROGATE_START + 0x3FF  # 0xDFFF
    length = len(s)
    i = 0
    while i < length:
        codepoint = ord(s[i])
        i += 1
        # Only look ahead when there is a next character; a trailing high surrogate used to
        # raise IndexError here.
        if HIGH_SURROGATE_START <= codepoint <= HIGH_SURROGATE_END and i < length:
            following = ord(s[i])
            # Only combine with a real low surrogate; anything else used to yield a garbage
            # (possibly negative) code point and swallow the following character.
            if LOW_SURROGATE_START <= following <= low_surrogate_end:
                high_surrogate = (codepoint - HIGH_SURROGATE_START) * 0x400
                low_surrogate = following - LOW_SURROGATE_START
                codepoint = (low_surrogate | high_surrogate) + 0x10000
                i += 1
        codepoint_to_utf8(codepoint, emit)
def string_from_utf8(byte: int, state: dict[str, int], emit: Callable) -> None:
    """Feed one UTF-8 byte into the streaming decoder, emitting a codepoint when one completes.

    Modifies `state` in place.

    Args:
    ----
        byte: The next UTF-8 byte in the sequence.
        state: The shared state between consecutive UTF-8 bytes in the sequence, as a dict with
            the shape `{"utf8seq": 0, "codepoint": 0}`. `utf8seq` counts the bytes still
            expected for the character in progress.
        emit: A function that is called for each codepoint.

    """
    pending = state["utf8seq"]
    if pending > 0:
        # In the middle of a multi-byte character: this must be a continuation byte.
        _continue_utf8_sequence(byte, state, emit)
        return
    if pending != 0:
        # A negative counter matches neither case; nothing is done.
        return

    if byte <= MAX_ASCII_VALUE:
        # A single-byte (ASCII) character is its own codepoint.
        emit(byte)
        return

    # First byte of a multi-byte character: record its length, take the payload bits of this
    # byte, then count this byte as consumed.
    length = _detect_utf8_sequence_length(byte)
    state["utf8seq"] = length
    state["codepoint"] = _initialize_codepoint(byte, length)
    state["utf8seq"] = length - 1
def _detect_utf8_sequence_length(byte: int) -> int:
    """Return the position (1-5) of the first zero bit below the top bit of `byte`.

    For the first byte of a UTF-8 sequence this is the sequence length: 2, 3 or 4. The values
    1 (a byte shaped like 0b10xxxxxx) and 5 (0b111110xx) can also be returned; they are not
    valid lengths and are left for the caller to reject. The top bit itself is not examined.

    Raises:
    ------
        ValueError: If bits 6 through 2 of `byte` are all set.

    """
    probe = 0b01000000  # start just below the top bit and walk downwards
    for length in (1, 2, 3, 4, 5):
        if not byte & probe:
            return length
        probe >>= 1
    message = "Invalid UTF-8 sequence"
    raise ValueError(message)
def _initialize_codepoint(byte: int, sequence_length: int) -> int:
    """Return the payload bits carried by the first byte of a UTF-8 sequence.

    The marker bits at the top of `byte` are stripped with the mask that matches
    `sequence_length`.

    Raises:
    ------
        ValueError: If `sequence_length` is not 2, 3 or 4.

    """
    lead_masks = {
        2: TWO_BYTE_UTF8_MASK,
        3: THREE_BYTE_UTF8_MASK,
        4: FOUR_BYTE_UTF8_MASK,
    }
    mask = lead_masks.get(sequence_length)
    if mask is None:
        message = "Invalid UTF-8 sequence"
        raise ValueError(message)
    return byte & mask
def _continue_utf8_sequence(byte: int, state: dict[str, int], emit: Callable[[int], None]) -> None:
    """Process one continuation byte of a multi-byte UTF-8 sequence.

    Appends the byte's six payload bits to `state["codepoint"]`, decrements `state["utf8seq"]`,
    and emits the finished codepoint once no more bytes are expected. Modifies `state` in place.

    Args:
    ----
        byte: The next byte of the sequence; must have the form 0b10xxxxxx.
        state: The UTF-8 decoding state with the keys `"utf8seq"` and `"codepoint"`.
        emit: A function that is called with the codepoint when the sequence is complete.

    Raises:
    ------
        ValueError: If `byte` is not a continuation byte.

    """
    # A continuation byte has its top two bits equal to 0b10. TWO_BYTE_UTF8_PREFIX (0b11000000)
    # doubles as the mask selecting those two bits. Checking only `byte < 0x80` let bytes in the
    # range 0xC0-0xFF (which start a new sequence or are invalid) through as continuations.
    if (byte & TWO_BYTE_UTF8_PREFIX) != CONTINUATION_BYTE_MASK:
        message = "Invalid UTF-8 continuation byte"
        raise ValueError(message)
    state["codepoint"] = (state["codepoint"] << UTF_8_PAYLOAD_BIT_COUNT) | (
        byte & UTF_8_PAYLOAD_MASK
    )
    state["utf8seq"] -= 1
    if state["utf8seq"] == 0:
        emit(state["codepoint"])
@jfeaver
Copy link
Author

jfeaver commented Nov 1, 2024

After finishing this adaptation, it was pointed out to me that I should use the standard library. Oops! The base64 module returns the same encodings (apart from the trailing `=` padding, which this implementation omits). I'll abandon this and just use base64.urlsafe_b64encode.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment