# LCG discrete-log solver (GitHub gist mjtb49/506ac15656bfad6f9307059257b07200,
# last active February 17, 2025 05:30).
#
# Earthcomputer requested that I write up an LCG discrete-log solver for
# arbitrary LCGs. I have now done so, but lazily: there are several points where
# this could be improved, and I am not entirely convinced by the approach. This
# solver assumes some sort of factorization is possible - in particular, it needs
# to factor both phi(m) and m at some point ... (description truncated in the
# scraped page).
#
# NOTE: GitHub flagged this file as possibly containing hidden or bidirectional
# Unicode text; review it in an editor that reveals hidden Unicode characters.
import math
import random

import sympy
import sympy as sp
class LCG:
    """One step of a linear congruential generator: ``seed -> (a*seed + b) mod m``.

    Instances compose like affine maps:

    - ``f @ g`` is the map that applies ``g`` first, then ``f``.
    - ``f ** k`` is ``k`` consecutive steps, computed by square-and-multiply.
    """

    def __init__(self, a, b, m):
        # a and b are stored reduced mod m; m must be non-zero.
        self.a = a % m
        self.b = b % m
        self.m = m

    def __repr__(self):
        return f"{type(self).__name__}({self.a}, {self.b}, {self.m})"

    def __matmul__(self, other):
        """Return the composition ``self ∘ other`` (apply ``other`` first)."""
        assert self.m == other.m
        return LCG(self.a * other.a, self.a * other.b + self.b, self.m)

    def __pow__(self, power, modulo=None):
        """Return the map equivalent to applying ``self`` ``power`` times.

        ``modulo`` is accepted only for compatibility with the 3-argument
        ``pow()`` protocol and is ignored; the modulus is always ``self.m``.

        Raises:
            ValueError: if ``power`` is negative. Previously this looped forever,
                because floor-halving a negative number never reaches 0.
        """
        if power < 0:
            raise ValueError("LCG power must be non-negative")
        result = LCG(1, 0, self.m)  # identity map
        current = self
        while power != 0:
            if power % 2 == 1:
                result @= current
            current @= current
            power //= 2
        return result

    def apply(self, seed):
        """Return the generator state that follows ``seed``."""
        return (self.a * seed + self.b) % self.m
def vp(a, p, e):
    """Return the p-adic valuation of ``a`` when ``a`` is only known modulo ``p**e``.

    A residue of zero means the true valuation is at least ``e``. In that case
    ``e`` itself is returned as a finite stand-in for infinity.
    """
    residue = a % p ** e
    if not residue:
        # "Properly" this is infinite; capping at e avoids defining infinity.
        return e
    valuation = 0
    while not residue % p:
        residue //= p
        valuation += 1
    return valuation
def ord_mod_odd_prime(a, p):
    """Return the multiplicative order of ``a`` modulo the prime ``p``.

    Method: for each prime power ``q**k`` exactly dividing ``p - 1``, the element
    ``a**((p - 1) // q**k)`` has q-power order. The number of q-th powerings
    needed to bring it to 1 is the q-part of the order of ``a``.

    This requires factoring ``p - 1``. The original author suspected order
    finding is roughly as hard as factoring p-1, and that this step might be
    removable.

    Raises:
        ValueError: if ``p`` divides ``a``, so ``a`` has no multiplicative
            order. Previously the q-power loop spun forever on 0.
    """
    assert sp.isprime(p)
    a %= p
    if a == 0:
        raise ValueError(f"a must be coprime to p={p}")
    group_order = p - 1
    order = 1
    for q, k in sp.factorint(group_order).items():
        element = pow(a, group_order // q ** k, p)
        while element != 1:
            element = pow(element, q, p)
            order *= q
    assert order <= group_order
    return order
def ord_mod_prime_power(a, p, e):
    """Return the multiplicative order of ``a`` modulo ``p**e``; ``a`` must be coprime to ``p``.

    The unit group of Z_p is split into a torsion part and a principal part:

    - Odd p: roots of unity mu_{p-1} times U = 1 + pZ_p. Reducing mod p
      projects onto mu; raising to the power p-1 projects onto U.
    - p = 2: {+1, -1} times U = 1 + 4Z_2. Negating when a = 3 (mod 4) projects
      onto U.

    The order of ``a`` is the lcm of the orders of its two components.
    """
    if e == 0:
        return 1
    assert a % p != 0
    if p != 2:
        torsion_order = ord_mod_odd_prime(a, p)
        principal = pow(a, p - 1, p ** e)
        principal_order = p ** (e - vp(principal - 1, p, e))
        return sp.lcm(torsion_order, principal_order)
    if e == 1:
        return 1
    if a % 4 == 3:
        # a = -(-a) with -a in U, so the sign contributes a factor of order 2.
        torsion_order = 2
        a = p ** e - a
    else:
        torsion_order = 1
    principal_order = p ** (e - vp(a - 1, p, e))
    return sp.lcm(torsion_order, principal_order)
def bs_gs(a, target, p):
    """Return the smallest ``k >= 0`` with ``a**k == target (mod p)``, or None.

    Baby-step giant-step in (Z/pZ)^*, whose order is ``p - 1``:

    - With ``m = 1 + isqrt(p - 2)`` we have ``m*m >= p - 1``.
    - So every exponent below ``p - 1`` can be written as ``j + m*i`` with
      ``0 <= j, i < m``.
    - Baby steps tabulate ``a**j``; giant steps repeatedly multiply the target
      by ``a**(-m)``.

    ``p`` must be prime and ``a`` invertible modulo ``p`` (both asserted).
    """
    target %= p
    a %= p
    assert sp.isprime(p) and a % p != 0
    step_size = 1 + math.isqrt(p - 2)
    giant_step = pow(a, -step_size, p)
    baby_steps = {}
    current = 1
    for j in range(step_size):
        baby_steps[current] = j
        current = current * a % p
        if current == 1:
            # The powers of a have cycled: the table already holds every power
            # once, with its smallest exponent. Continuing would overwrite those
            # entries with larger exponents and return a non-minimal logarithm.
            break
    for i in range(step_size):
        if target in baby_steps:
            return baby_steps[target] + step_size * i
        target = target * giant_step % p
    return None
def log_p(z, p, e):
    """Return the p-adic logarithm of ``z`` reduced modulo ``p**e``.

    Uses the series ``log(1 - x) = -sum_{n>=1} x**n / n`` with ``x = 1 - z``.
    It converges because ``x`` is required (by assert) to be divisible by p,
    or by 4 when p == 2.

    Powers of x are carried modulo ``p**(2e)``. This leaves enough digits that,
    after dividing out the p-part of n, the term is still correct mod
    ``p**e``. The author notes this bound is lazy.
    """
    x = 1 - z  # so that z = 1 - x
    x_val = vp(x, p, e)
    assert x_val >= (2 if p == 2 else 1)
    out_modulus = p ** e
    work_modulus = p ** (e + e)
    total = 0
    x_pow = x  # x**n, reduced mod work_modulus after the first term
    n = 1
    # vp(x**n / n) = n*vp(x) - vp(n) >= n*vp(x) - log(n)/log(p); stop once this
    # reaches e. The extra "- 1" is the author's deliberate safety margin.
    while e > n * x_val - math.log(n, p) - 1:
        p_part = p ** vp(n, p, e)
        total -= x_pow // p_part * pow(n // p_part, -1, out_modulus)
        x_pow = x_pow * (1 - z) % work_modulus
        n += 1
    return total % out_modulus
def dist_mod_prime_power(a, b, p, e, x, y):
    """Solve for the step counts that take ``x`` to ``y`` under ``s -> a*s + b`` mod ``p**e``.

    Returns:
        ``(k, order)`` with ``0 <= k < order``: the LCG started at ``x`` equals
        ``y`` modulo ``p**e`` after a number of steps congruent to ``k`` modulo
        ``order``. Returns None if ``y`` is never reached.

    ``a`` must be coprime to ``p`` (asserted).

    Method: with ``d = (a-1)*x + b`` the k-th state ``x_k`` satisfies
    ``(a-1)*x_k + b = a**k * d``. So reaching ``y`` means solving
    ``a**k = n / d`` with ``n = (a-1)*y + b``, modulo ``p**(e + vp(a-1))``.
    """
    assert sp.gcd(a, p) == 1
    modulus = p ** e
    a %= modulus
    b %= modulus
    x %= modulus
    y %= modulus
    if a == 1:
        # Use the equivalent multiplier 1 + p**e so that a - 1 is non-zero.
        a = modulus + 1
    d = (a - 1) * x + b
    n = (a - 1) * y + b
    # Dividing by (a - 1) costs vp(a - 1) digits of precision, so work mod a higher power.
    e += vp(a - 1, p, e)
    vpd = vp(d, p, e)
    if vpd != vp(n, p, e):
        return None
    d //= p ** vpd
    n //= p ** vpd
    e -= vpd
    if e == 0:
        # Every step count works (checked before inverting d modulo p**0 = 1).
        return 0, 1
    target = n * pow(d, -1, p ** e) % p ** e
    order = ord_mod_prime_power(a, p, e)
    # Now solve a**k = target (mod p**e); solutions are determined modulo `order`.
    if p % 2 == 1:
        return _dist_odd_prime_power(a, target, p, e, order)
    return _dist_power_of_two(a, target, e, order)


def _strip_common_p(loga, logt, p):
    """Cancel common factors of p from (loga, logt) until loga is a unit mod p.

    Returns the reduced pair, or None if logt runs out of factors of p first.
    In that case ``loga * k = logt`` has no solution.
    """
    while loga % p == 0:
        if logt % p != 0:
            return None
        loga //= p
        logt //= p
    return loga, logt


def _dist_odd_prime_power(a, target, p, e, order):
    """Solve ``a**k = target (mod p**e)`` for odd p and e >= 1; returns ``(k, order)`` or None."""
    # First solve modulo p.
    # TODO: parts of bs_gs could reuse the already-known `order`.
    k0 = bs_gs(a, target, p)
    if k0 is None:
        return None
    if e == 1:
        return k0 % order, order
    modulus = p ** e
    # Now a**((p-1)*k1 + k0) = target. Strip the a**k0 factor and raise a to the
    # power p-1 so that both sides lie in U = 1 + pZ_p, where log_p applies.
    target = target * pow(a, -k0, modulus) % modulus
    a = pow(a, p - 1, modulus)
    loga = log_p(a, p, e + 2) % modulus
    logt = log_p(target, p, e + 2) % modulus
    if loga == logt == 0:
        return k0 % order, order
    reduced = _strip_common_p(loga, logt, p)
    if reduced is None:
        return None
    loga, logt = reduced
    result = (pow(loga, -1, modulus) * logt % modulus) * (p - 1) + k0
    return result % order, order


def _dist_power_of_two(a, target, e, order):
    """Solve ``a**k = target (mod 2**e)`` for odd a and target; returns ``(k, order)`` or None."""
    p = 2
    if e <= 1:
        return 0, 1
    # (Z/2^e)^* = {+1, -1} x U with U = 1 + 4Z, which is not cyclic. Solve in U
    # with the logarithm, and constrain the parity of k separately by the signs.
    sign_residue = 0
    sign_modulus = 1
    if target % 4 == 3:
        if a % 4 == 1:
            # Powers of a stay in U and can never reach a target in -U.
            return None
        sign_modulus, sign_residue = 2, 1
    elif target % 4 == 1 and a % 4 == 3:
        sign_modulus, sign_residue = 2, 0
    modulus = p ** e
    # Project a and target onto U.
    a = a if a % 4 == 1 else modulus - a
    target = target if target % 4 == 1 else modulus - target
    loga = log_p(a, p, e) % modulus
    logt = log_p(target, p, e) % modulus
    if loga == logt == 0:
        # Here a and target are both +/-1, so only the sign constraint matters.
        return sign_residue % order, order
    reduced = _strip_common_p(loga, logt, p)
    if reduced is None:
        return None
    loga, logt = reduced
    result = pow(loga, -1, modulus) * logt % modulus
    if result % sign_modulus != sign_residue:
        return None
    return result % order, order
def solve_congruences(congruences):
    """Combine ``x ≡ r_i (mod m_i)`` into a single congruence (generalised CRT).

    Args:
        congruences: iterable of ``(residue, modulus)`` pairs with positive
            moduli. Moduli need not be coprime. The input is not modified.

    Returns:
        ``(r, M)`` with ``M = lcm(m_i)`` and ``0 <= r < M`` as plain ints.
        Returns ``(0, 1)`` for empty input, and None if the congruences are
        incompatible.

    Raises:
        ValueError: if any modulus is not positive.
    """
    residue, modulus = 0, 1  # x ≡ 0 (mod 1) holds for every integer x
    for res_i, mod_i in congruences:
        res_i, mod_i = int(res_i), int(mod_i)
        if mod_i <= 0:
            raise ValueError(f"modulus must be positive, got {mod_i}")
        g = math.gcd(modulus, mod_i)
        if (residue - res_i) % g != 0:
            return None
        reduced = mod_i // g
        # x = residue + modulus*t, where (modulus/g) * t ≡ (res_i - residue)/g (mod mod_i/g).
        u = pow(modulus // g, -1, reduced) if reduced != 1 else 0
        new_modulus = modulus * reduced  # lcm(modulus, mod_i)
        residue = (residue + modulus * u * ((res_i - residue) // g)) % new_modulus
        modulus = new_modulus
    return residue, modulus
def distance(lcg, x, y):
    """Return how many steps of ``lcg`` lead from ``x`` to ``y``.

    Returns:
        ``(first, period)`` such that ``first + period*k`` applications of
        ``lcg`` to ``x`` (k = 0, 1, 2, ...) yield ``y``; or None if ``y`` is
        never reached.

        - If ``y`` is reached only once, ``period`` is 0.
        - Otherwise ``period`` is minimal and non-zero.
        - ``first`` is always minimal, but may be >= ``period`` because of the
          nilpotent (non-invertible) part of the multiplier.

    Requires factoring the modulus via ``sympy.factorint``.
    """
    m = lcg.m
    a = lcg.a
    b = lcg.b
    assert m != 0
    x %= m
    y %= m
    distance_to_constant = 0
    # Split m = m_bad * m_good, where m_bad collects the full power of every
    # prime of m that also divides a. gcd is computed once per iteration.
    m_bad = 1
    g = math.gcd(a, m)
    while g != 1:
        m_bad *= g
        m //= g
        g = math.gcd(a, m)
    if m_bad != 1:
        # Modulo m_bad the sequence eventually settles on a constant; walk to it.
        eventual_constant = x
        while eventual_constant % m_bad != lcg.apply(eventual_constant) % m_bad:
            if eventual_constant == y:
                # The generator hits y once before falling into a pattern that
                # does not contain y.
                return distance_to_constant, 0
            distance_to_constant += 1
            eventual_constant = lcg.apply(eventual_constant)
        if eventual_constant % m_bad != y % m_bad:
            # The generator settles on a constant modulo m_bad that is
            # incompatible with y, and never reaches y before this.
            return None
    assert math.gcd(a, m) == 1
    if m == 1:
        return distance_to_constant, 1
    # a is invertible mod m_good: solve per prime power, then combine by CRT.
    congruences = []
    for p, e in sp.factorint(m).items():
        d = dist_mod_prime_power(a, b, p, e, x, y)
        if d is None:
            return None
        congruences.append(d)
    solution = solve_congruences(congruences)
    if solution is None:
        return None
    r1, r2 = solution
    assert r1 % r2 == r1
    assert r2 != 0
    # The mod-m_bad part only matches once the sequence has settled there.
    while r1 < distance_to_constant:
        r1 += r2
    return r1, r2
def main(bound=100, trials=100000, seed=None):
    """Randomised self-test of distance() against direct LCG stepping.

    Phase 1 checks cases where the target is reachable by construction.
    Phase 2 checks that whenever distance() reports "unreachable", brute-force
    stepping agrees. Progress counters are printed as the test runs. On the
    first failure the offending parameters are printed and the test stops.

    Args:
        bound: moduli are drawn uniformly from [1, bound - 1]; must be >= 2.
        trials: number of cases checked in each phase.
        seed: optional seed for a private ``random.Random``, so that a failing
            run can be reproduced. None seeds from system entropy.

    Returns:
        True if every check passed, False at the first failure.
    """
    if bound < 2:
        raise ValueError("bound must be at least 2")
    rng = random.Random(seed)

    # Phase 1: solvable cases. The target is `dist` steps from `start`.
    for i in range(trials):
        print(i)
        m = rng.randint(1, bound - 1)
        # randint is inclusive at both ends, so each value lies in [0, m - 1].
        a = rng.randint(0, m - 1)
        b = rng.randint(0, m - 1)
        start = rng.randint(0, m - 1)
        dist = rng.randint(0, m - 1)
        lcg = LCG(a, b, m)
        target = (lcg ** dist).apply(start)
        d = distance(lcg, start, target)
        if d is None:
            print("Uh oh")
            print(a, b, m, start, target, dist, d)
            return False
        reached = (lcg ** d[0]).apply(start)
        if (lcg ** (d[1] + d[0])).apply(start) != reached:
            print("Generator doesn't seem to have eventual period d[1]")
            print()
            print(a, b, m, start, target, dist, d, (lcg ** d[1]).apply(start))
            return False
        if reached != target % m:
            print("Logarithm computed incorrectly")
            print()
            print(a, b, m, start, target, dist, d, (lcg ** d[1]).apply(start))
            return False
        if d[1] == 0:
            if dist != d[0]:
                print("Logarithm didn't find target small distance")
                print(a, b, m, start, target, dist, d)
                return False
        elif dist % d[1] != d[0] % d[1] or dist < d[0]:
            print("Logarithm didn't find target distance")
            print(a, b, m, start, target, dist, d)
            return False
        if d[0] > d[1] > 0 and reached == (lcg ** (d[0] - d[1])).apply(start):
            print("d[0] not minimal")
            print(a, b, m, start, target, dist, d)
            return False

    # Phase 2: whenever distance() says "unreachable", verify by brute force.
    if bound < 3:
        # Every modulus is 1, so every target is reachable and no unsolvable
        # case exists; looping for one would never terminate.
        return True
    num_tested = 0
    while num_tested < trials:
        m = rng.randint(1, bound - 1)
        a = rng.randint(0, m - 1)
        b = rng.randint(0, m - 1)
        start = rng.randint(0, m - 1)
        target = rng.randint(0, m - 1)
        lcg = LCG(a, b, m)
        if distance(lcg, start, target) is not None:
            continue
        num_tested += 1
        print(num_tested)
        # There are only m states, so any reachable state first appears within
        # m steps. Walk a copy so the report below shows the original start.
        state = start
        for j in range(m):
            if state == target:
                print(f"Could not solve a={a} b={b} m={m} start={start} target={target}. Expected {j}")
                return False
            state = lcg.apply(state)
    return True


if __name__ == "__main__":
    raise SystemExit(0 if main() else 1)
# (End of the scraped gist page; GitHub's sign-in prompt for comments omitted.)