Last active
January 15, 2024 17:01
-
-
Save Yu212/4000b235e118d965897035c55920701b to your computer and use it in GitHub Desktop.
Small Secret Exponent Attack
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import sqrt, ceil, gcd, comb, log, exp | |
import subprocess | |
import time | |
def gen(size, delta_list):
    """Generate a prime pair and one private exponent per requested ratio.

    Two primes are drawn with ``random_prime`` from the range
    ``[2**(size//2 - 1), 2**(size//2)]`` and redrawn until
    ``gcd(p - 1, q - 1) == 2``.  For every ``delta`` in ``delta_list`` the
    exponent starts at ``int(exp(log(p*q) * delta))`` (roughly ``n**delta``,
    computed in floating point) and is decremented until it is coprime to
    ``phi // 2``.

    Returns:
        ``(p, q, d_list)`` where ``d_list`` is aligned with ``delta_list``.
    """
    half = size // 2
    upper = 2 ** half
    lower = 2 ** (half - 1)
    while True:
        p = random_prime(upper, lbound=lower)
        q = random_prime(upper, lbound=lower)
        if gcd(p - 1, q - 1) == 2:
            break
    phi = (p - 1) * (q - 1)
    d_list = []
    for delta in delta_list:
        # Natural-log size of the target exponent, then back to an integer.
        d = int(exp(log(p * q) * delta))
        while gcd(d, phi // 2) != 1:
            d -= 1
        d_list.append(d)
    return p, q, d_list
def decrypt(n, e, c, y):
    """Factor ``n`` from ``y = -(p + q) / 2`` and decrypt ciphertext ``c``.

    ``s = -2*y`` is the sum of the two prime factors, so they are the roots
    of ``T**2 - s*T + n``.  The larger root is taken as ``p``.

    Args:
        n: RSA modulus.
        e: public exponent.
        c: ciphertext.
        y: recovered value ``-(p + q) // 2``.

    Returns:
        ``pow(c, d, n)`` with ``d`` the inverse of ``e`` modulo
        ``(p - 1) * (q - 1)``.

    Raises:
        ValueError: if ``y`` does not lead to a non-trivial factorization of
            ``n`` (including a negative discriminant), or if ``e`` is not
            invertible modulo phi.
    """
    # Local import: the file-level ``from math import ...`` does not provide
    # isqrt.  math.isqrt accepts any integer-like object, so this works for
    # plain ints as well as Sage Integers (the original ``.isqrt()`` method
    # only exists on the latter).
    from math import isqrt

    s = -2 * y
    p = (s + isqrt(s * s - 4 * n)) // 2
    if not 1 < p < n or n % p != 0:
        raise ValueError("y does not yield a factorization of n")
    q = n // p
    phi = (p - 1) * (q - 1)
    d = pow(e, -1, phi)
    return pow(c, d, n)
def ssea(algo, m, t, plain, key):
    """Run the attack end to end on a known key and print the results.

    The public exponent is derived from the supplied private exponent, the
    lattice is built from ``polynomials(m, *hybrid_param(t))``, reduced with
    the selected backend, and the reduced vectors are turned back into
    polynomials that ``solve`` searches for a common integer root.  The
    recovered ``y`` is handed to ``decrypt``.

    Args:
        algo: reduction backend, one of ``"flatter"``, ``"self"``, ``"sage"``.
        m: first argument passed to ``polynomials``.
        t: argument passed to ``hybrid_param``.
        plain: plaintext integer that is encrypted and then recovered.
        key: tuple ``(p, q, d)``.

    Raises:
        ValueError: if ``algo`` is not one of the supported names.

    Side effects:
        Sets the module-level globals ``n`` and ``e`` and prints progress,
        timing and the decrypted value (or a failure note when ``solve``
        finds no root).
    """
    # f(), g() and h() read these two module-level names.
    global n, e
    p, q, d = key
    n = p * q
    phi = (p - 1) * (q - 1)
    e = int(pow(d, -1, phi // 2))
    A = (n + 1) // 2
    # Ground-truth solution, used only for the printout and the sanity check.
    x = (e * d - 1) // (phi // 2)
    y = -(p + q) // 2
    c = pow(plain, e, n)
    print(f"{p = }")
    print(f"{q = }")
    print(f"{e = }")
    print(f"{d = }")
    print(f"{A = }")
    print(f"{x = }")
    print(f"{y = }")
    # A + y == phi // 2, so x * (A + y) + 1 == e * d must be divisible by e.
    assert (x * (A + y) + 1) % e == 0
    # Scaling factors applied to X, Y and Z when building the lattice.
    xx = int(e ** 0.292)
    yy = int(e ** 0.5)
    zz = xx * yy + 1
    ps = polynomials(m, *hybrid_param(t))
    lat = lattice(ps, xx, yy, zz)
    print(f"LLL start: {len(lat)}")
    start = time.time()
    if algo == "flatter":
        mat = flatter(lat)
    elif algo == "self":
        mat = LLL(lat)
    elif algo == "sage":
        mat = matrix(ZZ, lat).LLL(delta=0.999)
    else:
        raise ValueError(algo)
    # Average bit length of the squared row norms of the reduced basis.
    avg_bits = sum(int(sum(v**2 for v in row)).bit_length() for row in mat) / len(lat)
    print(f"{algo}, {avg_bits:.0f}, {time.time() - start:.3f}s")
    start = time.time()
    degrees = list_degree(ps)
    # Undo the per-monomial scaling and rewrite every vector in X, Y only.
    qs = [
        Polynomial(
            {deg: coef // monomial(*deg)(xx, yy) for coef, deg in zip(vec, degrees)}
        ).remove_z()
        for vec in mat
    ]
    root = solve(qs)
    if root is None:
        # solve() falls through with None when no integer root is found;
        # report it instead of failing on the tuple unpacking below.
        print(f"no root found, {time.time() - start:.3f}s")
        return
    x, y = root
    print(decrypt(n, e, c, y), f"{time.time() - start:.3f}s")
def herrmann_may_lattice(m, tau):  # 0.292
    """Return ``polynomials(m, tau, eta)`` with ``eta`` fixed to 1.

    NOTE(review): the trailing ``0.292`` presumably names the bound on
    log_N(d) reached by this parameter choice -- confirm.
    """
    eta = 1
    return polynomials(m, tau, eta)
def blomer_may_lattice(m, t):  # 0.290
    """Return ``polynomials(m, tau, eta)`` with ``tau = 1`` and ``eta = t / m``.

    NOTE(review): the trailing ``0.290`` presumably names the bound on
    log_N(d) reached by this parameter choice -- confirm.
    """
    tau = 1
    eta = t / m
    return polynomials(m, tau, eta)
def polynomials(m, tau, eta):
    """Build the list of shift polynomials for the lattice.

    For every ``u`` from ``ceil(m * (1 - eta))`` to ``m`` inclusive:

    * X-shifts ``g(m, u - i, i)`` for ``i = 0..u``;
    * Y-shifts ``h(m, i, u)`` for ``i = 1..ceil(tau * (u - m * (1 - eta)))``.

    All X-shifts come first (ordered by ``u`` then ``i``), followed by the
    Y-shifts (ordered by ``u`` then ``i``).
    """
    first_u = ceil(m * (1 - eta))

    x_shifts = [
        (u - i, i)
        for u in range(first_u, m + 1)
        for i in range(0, u + 1)
    ]
    x_shifts.sort(key=lambda pair: (pair[0] + pair[1], pair[1]))

    y_shifts = [
        (i, u)
        for u in range(first_u, m + 1)
        for i in range(1, ceil(tau * (u - m * (1 - eta))) + 1)
    ]
    y_shifts.sort(key=lambda pair: (pair[1], pair[0]))

    g_polys = [g(m, i, j) for i, j in x_shifts]
    h_polys = [h(m, i, u) for i, u in y_shifts]
    return g_polys + h_polys
def list_degree(polynomials):
    """Return every exponent triple used by ``polynomials``, in lattice column order.

    Each element of ``polynomials`` must expose a ``coef`` mapping keyed by
    ``(deg_X, deg_Y, deg_Z)``.  The triples are sorted by ``deg_Y``, then by
    ``deg_X + deg_Z``, then by ``deg_Z``.
    """
    seen = {deg for poly in polynomials for deg in poly.coef}
    ordered = list(seen)
    ordered.sort(key=lambda deg: (deg[1], deg[0] + deg[2], deg[2]))
    return ordered
def lattice(polynomials, x, y, z):
    """Build the coefficient matrix of ``polynomials`` with scaled columns.

    Columns follow the order returned by ``list_degree``.  The entry for a
    polynomial and an exponent triple ``(a, b, c)`` is the coefficient of
    that monomial multiplied by ``x**a * y**b * z**c``.

    Args:
        polynomials: objects providing ``coef`` and ``get_coef(X, Y, Z)``.
        x, y, z: scaling factors for the three variables.

    Returns:
        A list of rows (lists), one per polynomial.
    """
    degrees = list_degree(polynomials)
    # The scale of a column does not depend on the row: compute it once
    # instead of once per polynomial.
    scales = [x ** dx * y ** dy * z ** dz for dx, dy, dz in degrees]
    return [
        [poly.get_coef(*deg) * scale for deg, scale in zip(degrees, scales)]
        for poly in polynomials
    ]
def hybrid_param(t):
    """Map ``t`` to the ``(tau, eta)`` pair passed on to ``polynomials``.

    ``tau = 1 - (2 - sqrt(2)) * t`` and
    ``eta = (sqrt(6) - 2) * t + (3 - sqrt(6))``.
    """
    tau_slope = 2 - 2 ** 0.5
    eta_slope = 6 ** 0.5 - 2
    eta_offset = 3 - 6 ** 0.5
    tau = 1 - tau_slope * t
    eta = eta_slope * t + eta_offset
    return tau, eta
# Z + AX
def f():
    """Return the polynomial ``Z + A*X`` with ``A = (n + 1) // 2``.

    ``n`` is read from module scope at call time.
    """
    half_n_plus_one = (n + 1) // 2
    z_term = monomial(Z=1)
    x_term = monomial(X=1, coef=half_n_plus_one)
    return z_term + x_term
# X^i * f^k * e^(m-k)
def g(m, i, k):
    """Return the normalized X-shift ``X**i * f()**k * e**(m - k)``.

    ``e`` is read from module scope at call time.
    """
    shift = monomial(X=i, coef=e ** (m - k))
    product = shift * f() ** k
    return product.normalize()
# Y^i * f^u * e^(m-u)
def h(m, i, u):
    """Return the normalized Y-shift ``Y**i * f()**u * e**(m - u)``.

    ``e`` is read from module scope at call time.
    """
    shift = monomial(Y=i, coef=e ** (m - u))
    product = shift * f() ** u
    return product.normalize()
def monomial(X=0, Y=0, Z=0, coef=1):
    """Return the single-term polynomial ``coef * X**X * Y**Y * Z**Z``.

    With no arguments this is the constant polynomial 1.
    """
    exponents = (X, Y, Z)
    return Polynomial({exponents: coef})
class Polynomial:
    """Sparse polynomial in X, Y, Z.

    ``coef`` maps an exponent triple ``(deg_X, deg_Y, deg_Z)`` to its
    coefficient.  Z is treated as a stand-in for ``X*Y + 1``: evaluation
    substitutes that value, ``normalize`` rewrites ``X*Y`` as ``Z - 1`` and
    ``remove_z`` expands ``Z`` back into ``X*Y + 1``.
    """

    def __init__(self, coef):
        self.coef = coef

    def __add__(self, other):
        """Return the term-wise sum (zero coefficients are kept)."""
        coef = self.coef.copy()
        for k, v in other.coef.items():
            coef[k] = coef.get(k, 0) + v
        return Polynomial(coef)

    def __mul__(self, other):
        """Return the product, adding exponent triples of every term pair."""
        coef = dict()
        for k1, v1 in self.coef.items():
            for k2, v2 in other.coef.items():
                k = (k1[0] + k2[0], k1[1] + k2[1], k1[2] + k2[2])
                coef[k] = coef.get(k, 0) + v1 * v2
        return Polynomial(coef)

    def __pow__(self, n):
        """Return ``self`` raised to the non-negative integer power ``n``.

        Raises:
            ValueError: if ``n`` is negative (previously this silently
                returned a copy of ``self``).
        """
        if n < 0:
            raise ValueError("negative exponents are not supported")
        if n == 0:
            # Constant polynomial 1.
            return Polynomial({(0, 0, 0): 1})
        result = Polynomial(self.coef.copy())
        for _ in range(1, n):
            result *= self
        return result

    def __call__(self, x, y):
        """Evaluate at ``X = x``, ``Y = y``, ``Z = x*y + 1``."""
        z = x * y + 1
        val = 0
        for k, v in self.coef.items():
            val += x ** k[0] * y ** k[1] * z ** k[2] * v
        return val

    def normalize(self):
        """Return a copy where every ``X*Y`` factor is rewritten as ``Z - 1``.

        For a term with ``xy = min(deg_X, deg_Y)`` the factor ``(X*Y)**xy``
        is expanded binomially as ``(Z - 1)**xy``.  Terms whose coefficient
        cancels to zero are dropped.
        """
        coef = dict()
        for k, v in self.coef.items():
            xy = min(k[0], k[1])
            for i in range(xy + 1):
                nk = (k[0] - xy, k[1] - xy, k[2] + i)
                # Sign of the Z**i term in (Z - 1)**xy is (-1)**(xy - i).
                sign = 1 if i % 2 == xy % 2 else -1
                coef[nk] = coef.get(nk, 0) + sign * v * comb(xy, xy - i)
                if coef[nk] == 0:
                    coef.pop(nk)
        return Polynomial(coef)

    def remove_z(self):
        """Return a copy with every ``Z**c`` expanded as ``(X*Y + 1)**c``.

        Terms whose coefficient cancels to zero are dropped.
        """
        coef = dict()
        for k, v in self.coef.items():
            for i in range(k[2] + 1):
                nk = (k[0] + i, k[1] + i, 0)
                coef[nk] = coef.get(nk, 0) + v * comb(k[2], i)
                if coef[nk] == 0:
                    coef.pop(nk)
        return Polynomial(coef)

    def monomials(self):
        """Return ``(exponents, coefficient)`` pairs sorted by deg_Y, deg_X + deg_Z, deg_Z."""
        return sorted(
            self.coef.items(),
            key=lambda e: (e[0][1], e[0][0] + e[0][2], e[0][2]),
        )

    def get_coef(self, X=0, Y=0, Z=0):
        """Return the coefficient of ``X**X * Y**Y * Z**Z`` (0 if absent)."""
        return self.coef.get((X, Y, Z), 0)

    def __str__(self):
        """Return a readable form such as ``1 + 2*X + X^2*Z``.

        A polynomial with no terms is rendered as ``"0"`` (previously an
        empty string).
        """
        s = []
        for k, v in self.monomials():
            t = []
            if v != 1:
                t.append(str(v))
            for c, x in zip("XYZ", k):
                if x == 1:
                    t.append(str(c))
                elif x >= 2:
                    t.append(f"{c}^{x}")
            if len(t) == 0:
                # Constant term with coefficient 1: nothing was emitted yet.
                t.append(str(v))
            s.append("*".join(t))
        if not s:
            return "0"
        return " + ".join(s)
def solve(qs):
    """Find a common integer root ``(x, y)`` of a leading subset of ``qs``.

    Each element of ``qs`` is evaluated at the generators of ``QQ[x, y]`` to
    obtain a Sage polynomial.  A binary search over prefix lengths looks for
    a prefix whose ideal is zero-dimensional; a dimension of -1 is treated
    as "prefix too long" and any other non-zero dimension as "prefix too
    short".  If the search ends without a hit, the remaining polynomials are
    added one by one to the last "too short" prefix, discarding any that
    make the system inconsistent.

    Returns:
        ``(x, y)`` taken from the first integer point of the first suitable
        ideal, or ``None`` if no such ideal is found.
    """
    # Same ring and generators as the Sage shorthand ``R.<x,y> = PolynomialRing(QQ)``.
    R = PolynomialRing(QQ, names=("x", "y"))
    x, y = R.gens()
    ps = [poly(x, y) for poly in qs]

    def _integer_root(ideal):
        # A zero-dimensional ideal may still have no integer point; indexing
        # the empty list used to raise IndexError.
        roots = ideal.variety(ring=ZZ)
        if not roots:
            return None
        root = roots[0]
        return root["x"], root["y"]

    left = 0
    right = len(ps) + 1
    while right - left > 1:
        mid = (left + right) // 2
        I = Sequence(ps[:mid], R).ideal()
        dim = I.dimension()
        if dim == -1:
            right = mid
        elif dim != 0:
            left = mid
        else:
            found = _integer_root(I)
            if found is not None:
                return found
            # No integer point here, and a longer prefix can only shrink the
            # solution set further: treat it like an inconsistent prefix.
            right = mid

    H = Sequence(ps[:left], R)
    for poly in ps[left:]:
        H.append(poly)
        I = H.ideal()
        dim = I.dimension()
        if dim == -1:
            H.pop()
        elif dim == 0:
            found = _integer_root(I)
            if found is not None:
                return found
            # This polynomial excluded every integer point; drop it again.
            H.pop()
    return None
def gram_schmidt(bases):
    """Orthogonalize ``bases`` over ``RR`` without normalizing.

    Returns:
        ``(gs_bases, gs_coef)``: the orthogonalized vectors and a square
        table in which ``gs_coef[i][j]`` (for ``j < i``) is the projection
        coefficient of vector ``i`` onto orthogonalized vector ``j``.  All
        other entries, including the diagonal, stay ``0.0``.
    """
    count = len(bases)
    bases = [vector(RR, row) for row in bases]
    gs_bases = [zero_vector(RR, count) for _ in range(count)]
    gs_coef = [[0.0] * count for _ in range(count)]
    for i, current in enumerate(bases):
        gs_bases[i] = current
        for j in range(i):
            earlier = gs_bases[j]
            gs_coef[i][j] = current.dot_product(earlier) / earlier.dot_product(earlier)
            gs_bases[i] -= gs_coef[i][j] * earlier
    return gs_bases, gs_coef
def LLL(basis, delta=0.75):
    """Reduce ``basis`` with a textbook LLL using floating-point Gram-Schmidt data.

    Args:
        basis: iterable of integer rows.
        delta: Lovasz parameter used in the swap condition.

    Returns:
        The reduced basis as a list of ``ZZ`` vectors.

    The Gram-Schmidt coefficients and squared norms are computed once by
    ``gram_schmidt`` and then updated incrementally.
    """
    basis = [vector(ZZ, base) for base in basis]
    n = len(basis)
    gs_basis, gs_coef = gram_schmidt(basis)
    gs_basis_dot = [base.dot_product(base) for base in gs_basis]
    k = 1
    while k < n:
        # Size-reduce row k against rows k-1, ..., 0.
        for j in reversed(range(k)):
            if abs(gs_coef[k][j]) > 0.5:
                r = round(gs_coef[k][j])
                basis[k] -= r * basis[j]
                for i in range(n):
                    gs_coef[k][i] -= r * gs_coef[j][i]
                # gram_schmidt leaves gs_coef[j][j] at 0, so the loop above
                # did not touch gs_coef[k][j]; bring it into (-0.5, 0.5] here.
                gs_coef[k][j] %= 1
                if 0.5 < gs_coef[k][j]:
                    gs_coef[k][j] -= 1
        if gs_basis_dot[k] >= (delta - gs_coef[k][k-1] ** 2) * gs_basis_dot[k-1]:
            k += 1
        else:
            basis[k], basis[k-1] = basis[k-1], basis[k]
            mu_prime = gs_coef[k][k-1]
            # Squared norm of the new (k-1)-th orthogonal vector
            # b*_k + mu * b*_{k-1}: the coefficient enters squared.
            b = gs_basis_dot[k] + mu_prime ** 2 * gs_basis_dot[k-1]
            # New projection coefficient between the two swapped rows; the
            # update for rows j > k below relies on this refreshed value.
            gs_coef[k][k-1] = mu_prime * gs_basis_dot[k-1] / b
            gs_basis_dot[k] = gs_basis_dot[k] * gs_basis_dot[k-1] / b
            gs_basis_dot[k-1] = b
            for j in range(k-1):
                gs_coef[k-1][j], gs_coef[k][j] = gs_coef[k][j], gs_coef[k-1][j]
            for j in range(k+1, n):
                t = gs_coef[j][k]
                gs_coef[j][k] = gs_coef[j][k-1] - mu_prime * t
                gs_coef[j][k-1] = t + gs_coef[k][k-1] * gs_coef[j][k]
            k = max(1, k-1)
    return basis
def flatter(bs):
    """Reduce the lattice ``bs`` by piping it through the ``flatter`` executable.

    The rows are sent on stdin as ``[[a b c]\\n[d e f]]`` and the program's
    stdout is parsed back into a list of integer rows.

    Raises:
        subprocess.CalledProcessError: if ``flatter`` exits with a non-zero
            status (previously this surfaced later as an unrelated parsing
            error).
        FileNotFoundError: if the executable is not on ``PATH``.
    """
    rows = ["[" + " ".join(map(str, b)) + "]" for b in bs]
    payload = "[" + "\n".join(rows) + "]"
    proc = subprocess.run(
        ["flatter"],
        input=payload,
        text=True,
        stdout=subprocess.PIPE,
        check=True,
    )
    cleaned = proc.stdout.replace("[", "").replace("]", "")
    # split() without arguments tolerates repeated or trailing blanks, and
    # blank lines are skipped, so minor formatting differences do not break
    # the parse.
    return [
        [int(token) for token in line.split()]
        for line in cleaned.splitlines()
        if line.strip()
    ]
def manual():
    """Read one comma-separated parameter line from stdin and run a single attack.

    Expected fields: ``size,E,m,t,algo`` where ``size`` and ``m`` are
    integers, ``E`` and ``t`` are floats and ``algo`` is the backend name
    passed to ``ssea``.  Whitespace around any field is ignored, so
    ``"64, 0.292, 15, 1, flatter"`` is accepted as well.

    Raises:
        ValueError: if the line does not have exactly five fields or a
            numeric field cannot be parsed.
    """
    raw = input("input params (ex: \"64,0.292,15,1,flatter\"): ")
    # Strip each field: a padded algo name would otherwise only be rejected
    # by ssea after the key and the lattice have already been built.
    size, E, m, t, algo = (field.strip() for field in raw.split(","))
    size, E, m, t = int(size), float(E), int(m), float(t)
    p, q, d_list = gen(size, [E])
    d = d_list[0]
    print(f"solve with param: E={E:.2f} real={log(d)/log(p*q):.4f}")
    ssea(algo, m, t, 998244353, (p, q, d))
def bench():
    """Run the attack with the ``flatter`` backend for a range of exponent sizes.

    One 64-bit key pair is generated together with a private exponent for
    each ratio 0.27, 0.26, ..., 0.10 (in that order); every exponent is
    then attacked with ``m = 20`` and ``t = 0``.
    """
    delta_list = [float(i / 100) for i in reversed(range(10, 28))]
    p, q, d_list = gen(64, delta_list)
    modulus_log = None
    for delta, d in zip(delta_list, d_list):
        print(f"solve with param: E={delta:.2f} real={log(d)/log(p*q):.4f}")
        ssea("flatter", 20, 0, 998244353, (p, q, d))
        print()
# Script entry point: prompt for parameters and run a single attack.
if __name__ == "__main__":
    manual()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment