Skip to content

Instantly share code, notes, and snippets.

@iscgar
Last active March 7, 2023 05:16
Show Gist options
  • Save iscgar/b77caf9a8b4982a1002111ba46f0e701 to your computer and use it in GitHub Desktop.
Save iscgar/b77caf9a8b4982a1002111ba46f0e701 to your computer and use it in GitHub Desktop.
import struct
import itertools
from base64 import b64encode
from retrie.trie import Trie
# Single characters that must not be placed verbatim inside a ``[...]``
# class: they are either special inside a class or are not plain literals
# outside of one, so merging them into a class would change the pattern.
_CLASS_UNSAFE = frozenset('\\[]^-.$|()?*+{}')


def _atom_end(s, i):
    r"""Return the index just past the regex atom starting at ``s[i]``.

    An atom is an escape pair (``\x``), a ``[...]`` character class, a
    parenthesised group (skipped as a unit, including anything nested in it)
    or any other single character.  Trailing quantifiers are not consumed.
    """
    c = s[i]
    if c == '\\':
        return i + 2
    if c == '[':
        j = i + 1
        # A leading '^' and a ']' directly after the opening bracket (or after
        # that '^') belong to the class rather than closing it.
        if j < len(s) and s[j] == '^':
            j += 1
        if j < len(s) and s[j] == ']':
            j += 1
        while j < len(s) and s[j] != ']':
            j += 2 if s[j] == '\\' else 1
        return j + 1
    if c == '(':
        j = i + 1
        while j < len(s) and s[j] != ')':
            j = _atom_end(s, j)
        return j + 1
    return i + 1


def _suffix_cuts(s):
    """Return every suffix length at which *s* can be split between tokens.

    A token is an atom (see ``_atom_end``) plus the quantifiers that follow
    it, so no returned cut leaves a suffix that starts with a dangling
    quantifier or begins inside an escape, class or group.  ``0`` is always
    included.
    """
    cuts = {0}
    i = 0
    while i < len(s):
        cuts.add(len(s) - i)
        i = _atom_end(s, i)
        # Quantifiers ('?', '*', '+', '{m,n}' and lazy '?' suffixes) bind to
        # the preceding atom.
        while i < len(s) and s[i] in '?*+{':
            if s[i] == '{':
                close = s.find('}', i)
                i = i + 1 if close == -1 else close + 1
            else:
                i += 1
    return cuts


def _merge_group(group, matching):
    """Build one pattern that matches exactly the alternatives in *group*.

    *group* holds reversed patterns whose first *matching* characters are
    shared, i.e. the original patterns share a suffix of that length.
    """
    word = group[0][::-1]
    if len(group) == 1:
        return word
    # Never split a token: shrink the shared suffix to the longest length at
    # which every member can be cut cleanly.
    cuts = set.intersection(*(_suffix_cuts(w[::-1]) for w in group))
    matching = max(c for c in cuts if c <= matching)
    # Slice from the front: ``word[-0:]`` would be the whole word.
    suffix = word[len(word) - matching:]
    heads = [w[matching:][::-1] for w in group]
    if all(len(h) == 1 and h not in _CLASS_UNSAFE for h in heads):
        return '[{}]{}'.format(''.join(heads), suffix)
    return '({}){}'.format('|'.join(heads), suffix)


def commonise_group(pat):
    """Factor common suffixes out of a list of regex alternatives.

    The alternatives are compared back-to-front.  After their reversed forms
    are sorted, each one is grouped with its predecessor when their shared
    suffix is at least half as long as the longer of (the group's longest
    alternative, the new alternative).  Each group is merged into one pattern:
    ``[xy]suffix`` when every alternative differs only by one literal
    character, ``(head1|head2)suffix`` otherwise.  The shared suffix is never
    split inside an escape, character class, group or quantifier.

    :param pat: list of regex alternatives (e.g. the parts of ``a|b|c``).
    :return: one merged pattern per group, ordered by where each group's
        representative (the member whose reversed form sorts first) appears
        in *pat*.
    :raises ValueError: if *pat* contains the same alternative twice.
    """
    if not pat:
        return []
    patr = sorted(s[::-1] for s in pat)
    common = [patr[0]]
    pat_map = {}
    matching = longest = len(patr[0])
    for w in patr[1:]:
        prev = common[-1]
        if w == prev:
            raise ValueError("cannot have identical patterns in group")
        # Common prefix of the reversed words = common suffix of the originals.
        # A word that merely extends ``prev`` shares all of ``prev``.
        shared = next(
            (k for k, (a, b) in enumerate(zip(prev, w)) if a != b),
            min(len(prev), len(w)),
        )
        if shared >= max(longest, len(w)) // 2:
            common.append(w)
            longest = max(longest, len(w))
            # In sorted order, the prefix shared by the whole group is the
            # minimum of the prefixes shared by neighbours.
            matching = min(matching, shared)
        else:
            pat_map[common[0][::-1]] = _merge_group(common, matching)
            common = [w]
            matching = longest = len(w)
    pat_map[common[0][::-1]] = _merge_group(common, matching)
    return [pat_map[w] for w in pat if w in pat_map]
def shorten_pat(pattern):
    r"""Rewrite a regex into a more compact but equivalent-looking form.

    Used by ``regex_for_prefix`` on the output of retrie's ``Trie.pattern()``.
    The pattern is scanned one character at a time:

    * Inside a ``[...]`` character class, consecutive characters of the same
      category (digits, upper-case letters, lower-case letters) are buffered,
      sorted and re-emitted.  A run of more than three characters with
      consecutive code points is emitted as an ``x-y`` range.
    * Outside a character class, each parenthesised group is split on ``|``
      into alternatives.  These are passed through ``commonise_group`` and
      re-joined inside ``(...)``; a ``?:`` prefix is kept.
    * A backslash and the character after it are copied through unchanged.

    The top-level alternatives are likewise passed through ``commonise_group``.

    :param pattern: regex string to shorten.  NOTE(review): brackets and
        parentheses are assumed balanced; this is only checked with ``assert``.
    :return: the shortened regex string.
    """
    # Alternatives of the group currently being scanned; characters are
    # appended to the last one.
    pat = ['']
    # (group_prefix, pat) of each enclosing group, saved on '(' and restored on ')'.
    group_stack = []
    in_alt = False  # currently inside a [...] character class
    in_escape = False  # previous character was a backslash
    last_cat = -1  # category of the last digit/letter seen inside a class
    cats = []  # class characters buffered for sorting / range folding
    group_prefix = ''  # '?:' when the current group is non-capturing
    for c in pattern:
        if not in_escape:
            if in_alt:
                # Categorise the class character: 0 = digit, 1 = upper-case,
                # 2 = lower-case, -1 = anything else.
                cat = -1
                if '0' <= c <= '9':
                    cat = 0
                elif 'A' <= c <= 'Z':
                    cat = 1
                elif 'a' <= c <= 'z':
                    cat = 2
                if cat != last_cat:
                    # Category changed: flush the buffered characters, folding
                    # more than three consecutive code points into a range.
                    cats.sort()
                    if len(cats) > 3 and ord(cats[-1]) - ord(cats[0]) == len(cats) - 1:
                        pat[-1] += '{}-{}'.format(cats[0], cats[-1])
                    else:
                        pat[-1] += ''.join(cats)
                    del cats[:]
                if cat != -1:
                    # Buffer digits/letters; they are emitted on the next flush.
                    cats.append(c)
                    last_cat = cat
                    continue
            if c == '\\':
                in_escape = True
            elif c == '[':
                assert not in_alt
                in_alt = True
            elif c == ']':
                assert in_alt
                in_alt = False
            elif not in_alt:
                if c == '(':
                    # Open a nested group with its own list of alternatives.
                    group_stack.append((group_prefix, pat))
                    group_prefix, pat = '', ['']
                    continue
                elif c == ')':
                    # Close the group: merge its alternatives and append the
                    # result to the enclosing group's current alternative.
                    gp, opt = group_prefix, commonise_group(pat)
                    group_prefix, pat = group_stack.pop(-1)
                    pat[-1] += '({}{})'.format(gp, '|'.join(opt))
                    continue
                elif c == '|':
                    pat.append('')
                    continue
                elif c == ':' and group_stack and pat[-1] == '?':
                    # '?:' at the start of an alternative inside a group: keep
                    # it as the group prefix, not as alternative text.
                    assert not group_prefix
                    pat[-1] = ''
                    group_prefix = '?:'
                    continue
        else:
            # Character following a backslash: copy it through literally.
            in_escape = False
        pat[-1] += c
    # NOTE(review): these checks are asserts, so they vanish under ``python -O``.
    assert not in_alt and not in_escape and not group_stack
    assert len(pat) == 1 or pat[-1]
    return '|'.join(commonise_group(pat))
def regex_for_prefix(prefix, tail_len):
    """Build a regex that finds *prefix* inside base64-encoded data.

    Base64 encodes 3-byte blocks as 4 characters.  How *prefix* looks once
    encoded therefore depends on where it starts within a block and on the
    bytes immediately around it.  Every variant is encoded: the prefix
    starting at block offset 0, 1 or 2, with every value of the free byte
    before and/or after it.  Each encoding is trimmed to the characters
    influenced by *prefix* and added to a trie.  The trie's pattern is then
    compacted with ``shorten_pat``.

    :param prefix: the raw ``bytes`` to search for.
    :param tail_len: number of bytes expected after *prefix*.  When large
        enough, a minimum run of base64 characters is required after the
        encoded prefix.
    :return: the regex as a string.
    """
    # Every possible byte value (used to pad the prefix on either side to
    # account for encoding alignment).
    padding = range(256)
    # We build a trie in order to get the most compressed form of the
    # resulting pattern.
    t = Trie()
    # A base64 encoding block is 3 bytes long, so we need to account for the
    # prefix beginning in any of an encoding block's slots.
    for i in range(3):
        lead = b'A' * max(0, i - 1)
        # If the length of the prefix plus the current slot offset isn't
        # divisible by the block length, pad it so that we get every value
        # that could appear after the prefix in the encoded form.
        pad_len = (3 - (len(prefix) + i) % 3) % 3
        pads = b'A' * max(0, pad_len - 1)
        # One free byte directly before the prefix (when it is not
        # block-aligned) and one directly after it (when the block needs
        # filling).  ``product`` enumerates every combination, including
        # equal byte values, which ``permutations`` would skip.
        n_free = int(bool(i)) + int(bool(pad_len))
        for r in itertools.product(padding, repeat=n_free):
            source = lead
            if i:
                source += struct.pack('<B', r[0])
            source += prefix
            if pad_len:
                source += struct.pack('<B', r[-1]) + pads
            # Encode, then strip the leading characters that only encode the
            # bytes before the prefix.  Since the encoded length is strictly
            # bigger than the source length, stripping as many characters as
            # the slot index never touches the encoded prefix.
            encoded = b64encode(source)[i:]
            # Likewise strip the trailing characters that only encode the
            # padding bytes.
            if pad_len > 0:
                encoded = encoded[:-pad_len]
            t.add(encoded.decode('ascii'))
    # Extract a pattern that describes this trie and optimise it a bit.
    pat = shorten_pat(t.pattern())
    # Add a pattern for the tail (because we need to see at least this many
    # bytes as well).
    total_len = len(prefix) + tail_len
    left = total_len - (len(prefix) + 2)
    if left > 0:
        groups = (left + 3) // 4
        pat += '(?:{})'.format('|'.join(
            r'[\+\/A-Za-z0-9]{{{}}}{}'.format(groups * 4 - n_eq, '=' * n_eq)
            for n_eq in range(3)))
    return pat
@gofri
Copy link

gofri commented Mar 7, 2023

P.S. I'm using the following script to generate test vectors (not perfect, obviously, but helpful enough):

#!/usr/bin/env python3
import random

def range_chrs(a, z):
    """Return every character from *a* through *z* inclusive, in code-point order."""
    first, last = ord(a), ord(z)
    return list(map(chr, range(first, last + 1)))

def pat_opts():
    """Return the alphanumeric characters: lower-case, then upper-case, then digits."""
    lower = range_chrs('a', 'z')
    upper = range_chrs('A', 'Z')
    digits = range_chrs('0', '9')
    return lower + upper + digits

def ascii_opts():
    """Return the filler characters: every code point from '0' through 'z'."""
    filler = range_chrs('0', 'z')
    return filler

def gen_one(pref, tail):
    """Return *pref* followed by *tail* random alphanumerics, wrapped in noise.

    Between zero and five random filler characters are placed on each side.
    """
    def noise():
        return ''.join(random.choices(ascii_opts(), k=random.randint(0, 5)))

    head = noise()
    body = pref + ''.join(random.choices(pat_opts(), k=tail))
    return head + body + noise()

if __name__ == '__main__':
    import sys

    # Usage: <script> PREFIX TAIL_LEN -- prints one generated test vector.
    if len(sys.argv) < 3:
        sys.exit('usage: {} PREFIX TAIL_LEN'.format(sys.argv[0]))
    pref = sys.argv[1]
    tail = int(sys.argv[2])
    print(gen_one(pref, tail))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment