Skip to content

Instantly share code, notes, and snippets.

@pjt33
Last active April 11, 2026 08:25
Show Gist options
  • Select an option

  • Save pjt33/685000a7b90851364254a3dd71c9027f to your computer and use it in GitHub Desktop.

Select an option

Save pjt33/685000a7b90851364254a3dd71c9027f to your computer and use it in GitHub Desktop.
BPSW
from typing import Dict, List, Iterable, Optional
from math import gcd
def is_probable_prime(n: int) -> bool:
    """Baillie-Pomerance-Selfridge-Wagstaff probable-prime test.

    n must pass a strong (Miller-Rabin) test to base 2 and then a Lucas
    test whose parameters are chosen by Selfridge's method; the Lucas test
    is only run when the base-2 test passes.
    """
    if not _is_probable_prime_miller_rabin(n, 2):
        return False
    return _is_probable_prime_selfridge(n)
def _is_probable_prime_miller_rabin(n, a):
    """Strong probable-prime (Miller-Rabin) test of n to base a.

    If n is prime then a^{n-1} = 1 (mod n) - the Fermat test. Miller-Rabin
    strengthens the test by looking at the successive square roots
    a^{(2^t)s} (mod n), where n - 1 = (2^r)s with s odd.

    :param n: the integer to test.
    :param a: the base; must be at least 2. Only a mod n matters.
    :returns: False if n < 2 or n is shown to be composite; True if n == 2
        or n is a strong probable prime to base a.
    :raises ValueError: if a < 2.
    """
    if a < 2:
        raise ValueError("base a must be at least 2")
    if n < 2:
        return False
    if (n & 1) == 0:
        return n == 2
    a %= n
    if a == 0:
        # A base that is a multiple of n gives a^s = 0 (mod n) for every n,
        # so it carries no evidence of compositeness. Returning False here
        # would wrongly reject primes (e.g. n = 3, a = 3).
        return True
    # Decompose n - 1 as (2^r)s: R = 2^r is the lowest set bit of n - 1.
    nm1 = n - 1
    R = nm1 & -nm1
    s = nm1 // R
    apow = pow(a, s, n)
    if apow == 1 or apow == nm1:
        return True
    # Check a^{(2^t)s} mod n == n - 1 for 1 <= t < r (t = 0 was checked above).
    while R > 2:
        R >>= 1
        apow = apow * apow % n
        if apow == nm1:
            return True
    return False
def _is_probable_prime_selfridge(n):
    """Lucas probable-prime test with parameters chosen by Selfridge's method.

    Returns False for n < 2 and for even n other than 2. For odd n, searches
    D = 5, -7, 9, -11, 13, ... for the first D with jacobi(D, n) == -1, then
    uses P = 1 and Q = (1 - D) / 4 and reports whether U_{n+1}(P, Q) == 0
    (mod n).
    """
    if n < 2:
        return False
    if not n & 1:
        return n == 2

    d = 5
    while True:
        symbol = jacobi(d, n)
        if symbol == -1:
            break
        if symbol == 0:
            # gcd(d, n) > 1, so either n divides d (n == |d|, the prime case)
            # or n is composite. Because d = 3 is skipped, the composite
            # n = 9 can also reach n == |d|, hence its explicit exclusion.
            return abs(d) == n and d != 9
        # jacobi(d, m*m) is never -1, so a perfect square would keep
        # scanning until |d| hits a factor of n; stop early instead.
        if d == 13 and isqrt(n) ** 2 == n:
            return False
        # Next term of 5, -7, 9, -11, 13, ...: flip the sign, grow |d| by 2.
        d = -d - 2 if d > 0 else -d + 2

    q = (1 - d) >> 2
    common = gcd(abs(d * q << 1), n)
    if common not in (1, n):
        return False

    inv2 = (n + 1) >> 1  # 2 * inv2 == n + 1 == 1 (mod n)
    u, v, qk = 1, 1, q  # U_1, V_1 and Q^1 for P = 1
    # Walk the binary expansion of n + 1, skipping its leading 1 bit.
    for bit in bin(n + 1)[3:]:
        # Doubling: U_2k = U_k V_k, V_2k = V_k^2 - 2Q^k, Q^2k = (Q^k)^2.
        u = u * v % n
        v = (v * v - 2 * qk) % n
        qk = qk * qk % n
        if bit == "1":
            # Step k -> k + 1: U_{k+1} = (U_k + V_k) / 2 and
            # V_{k+1} = (D U_k + V_k) / 2 = U_{k+1} - 2Q U_k.
            # Left unreduced here; the next doubling reduces mod n.
            u_next = (u + v) * inv2 % n
            v = u_next - 2 * q * u
            u = u_next
            qk = qk * q
    return u == 0
def jacobi(n: int, k: int) -> int:
    """Return the Jacobi symbol (n / k): one of -1, 0 or 1.

    k must be a positive odd integer; n may be any integer (it is reduced
    mod k first). Uses the binary algorithm: strip factors of two, then apply
    quadratic reciprocity by swapping the arguments.

    :raises ValueError: if k <= 0 or k is even.
    """
    if k <= 0 or not k & 1:
        raise ValueError
    a, m = n % k, k
    sign = 1
    while a != 0:
        # (2 / m) == -1 exactly when m = 3 or 5 (mod 8).
        while a % 2 == 0:
            a //= 2
            if m % 8 in (3, 5):
                sign = -sign
        # Reciprocity: swapping flips the sign iff both are 3 (mod 4).
        a, m = m, a
        if a % 4 == 3 and m % 4 == 3:
            sign = -sign
        a %= m
    # A final modulus other than 1 means gcd(n, k) > 1.
    return sign if m == 1 else 0
def to_base(n: int, b: int) -> Iterable[int]:
    """Extract the base-b digits of n in big-endian order (most significant first).

    Zero has no digits, so to_base(0, b) yields nothing.

    :param n: the non-negative integer to convert.
    :param b: the base; must be at least 2.
    :raises ValueError: if n < 0 or b < 2. Without this check a negative n
        loops forever (floor division by b never leaves -1), b == 1 loops
        forever and b == 0 divides by zero.
    """
    if b < 2:
        raise ValueError("base b must be at least 2")
    if n < 0:
        raise ValueError("n must be non-negative")
    digits = []
    while n:
        n, lsd = divmod(n, b)
        digits.append(lsd)
    return reversed(digits)
def isqrt(n: int) -> int:
    """Integer square root of n, rounded down (the k = 2 case of ikth_root)."""
    return ikth_root(k=2, n=n)
def ikth_root(k: int, n: int) -> int:
    """Integer k-th root: return the largest s >= 0 with s ** k <= n.

    Works for arbitrarily large n. A float estimate is used only as a
    starting point; the result is settled with exact integer arithmetic.

    :param k: the root degree; must be at least 1.
    :param n: the radicand; must be non-negative.
    :raises ValueError: if k < 1 or n < 0.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    if n < 0:
        raise ValueError("n must be non-negative")
    if n in (0, 1):
        return n
    try:
        # Float estimate, accurate to roughly double precision when n fits.
        s = int(n ** (1.0 / k))
    except OverflowError:
        # n is too large for a float: 2^ceil(bits / k) over-estimates the root,
        # so the Newton descent below starts close and from above.
        s = 1 << -(-n.bit_length() // k)
    s = max(s, 1)  # never divide by a zero estimate
    # Integer Newton step for f(x) = x^k - n:
    #     x' = floor(((k - 1) x + n / x^{k-1}) / k)
    # By AM-GM, one step from any x > 0 lands at or above floor(n^(1/k)).
    # From there the iteration strictly decreases until it reaches the floor
    # root, where the next step no longer decreases.
    s = ((k - 1) * s + n // s ** (k - 1)) // k
    while True:
        t = ((k - 1) * s + n // s ** (k - 1)) // k
        if t >= s:
            return s
        s = t
# Demo: print the probable primes between 2 and 99 inclusive.
print(list(filter(is_probable_prime, range(2, 100))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment