Last active
April 11, 2026 08:25
-
-
Save pjt33/685000a7b90851364254a3dd71c9027f to your computer and use it in GitHub Desktop.
BPSW
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Dict, List, Iterable, Optional | |
| from math import gcd | |
def is_probable_prime(n: int) -> bool:
    """
    Baillie-Pomerance-Selfridge-Wagstaff (BPSW) pseudo-primality test.

    Runs a strong Miller-Rabin test to base 2 and, only if that passes, a Lucas
    test with Selfridge's parameter choice. Returns True when n passes both;
    False means n < 2 or one of the tests showed n to be composite.
    """
    # Cheap base-2 strong test first; the Lucas test only runs for survivors.
    if not _is_probable_prime_miller_rabin(n, 2):
        return False
    return _is_probable_prime_selfridge(n)
def _is_probable_prime_miller_rabin(n: int, a: int) -> bool:
    """
    Strong (Miller-Rabin) probable-prime test of n to the single base a.

    Returns False if n < 2 or if base a witnesses that n is composite, and True
    if n is prime or a strong pseudoprime to base a. A base that is a multiple
    of n gives a^s = 0 (mod n) and carries no information, so it is treated as
    passing; this keeps a prime n from being rejected merely because a >= n.

    Raises ValueError if a < 2.
    """
    # If n is prime then a^{n-1} = 1 (mod n) - the Fermat test.
    # Miller-Rabin strengthens the test by looking at successive square roots.
    if a < 2:
        raise ValueError
    if n < 2:
        return False
    if (n & 1) == 0:
        return n == 2
    # Work with the base reduced mod n; a multiple of n cannot witness anything.
    a %= n
    if a == 0:
        return True
    # Decompose n - 1 as (2^r)s with s odd; R = 2^r is the lowest set bit of n - 1.
    nm1 = n - 1
    R = nm1 & -nm1
    s = nm1 // R
    apow = pow(a, s, n)
    if apow == 1 or apow == nm1:
        return True
    # Check a^{(2^t)s} mod n for 1 <= t < r (t = 0 was checked above).
    while R > 2:
        R >>= 1
        apow = apow * apow % n
        if apow == nm1:
            return True
    return False
def _is_probable_prime_selfridge(n: int) -> bool:
    """
    Lucas probable-prime test with parameters chosen by Selfridge's method.

    Picks the first d in 5, -7, 9, -11, 13, ... with jacobi(d, n) == -1, sets
    P = 1 and Q = (1 - d) / 4, and checks U_{n+1}(P, Q) == 0 (mod n) using a
    left-to-right binary ladder over the bits of n + 1.

    Returns False if n < 2 or n is shown to be composite; True if n is prime or
    a Lucas pseudoprime for the selected parameters.
    """
    if n < 2:
        return False
    if (n & 1) == 0:
        return n == 2
    # Selfridge: find the first d in the sequence 5, -7, 9, -11, 13, ... for which jacobi(d, n) = -1.
    # Then use p = 1 and q = (1 - d) / 4 in the Lucas pseudoprime test.
    d = 5
    while True:
        j = jacobi(d, n)
        # If j == 0, gcd(d, n) > 1 so either n | d or n is composite...
        # Because we skipped d = 3, we can have d == n for composite n if n == 9.
        # Otherwise n | d => n == d and no smaller odd number (except 3) has non-trivial gcd with n.
        if j == 0:
            return abs(d) == n and d != 9
        if j == -1:
            break
        # If n is a perfect square, jacobi(d, n) is never -1, so we would loop until d = +/- n;
        # detect squares once, here, to avoid that.
        if d == 13 and isqrt(n) ** 2 == n:
            return False
        # Next term: the magnitude grows by 2 and the sign alternates (5 -> -7 -> 9 -> -11 -> ...).
        d = (2 if d < 0 else -2) - d
    # Every d in the sequence is 1 (mod 4), so the shift is an exact division by 4.
    q = (1 - d) >> 2
    # A common factor of n with 2*d*q that is neither 1 nor n proves n composite.
    g = gcd(abs(d * q << 1), n)
    if g != 1 and g != n:
        return False
    half = (n + 1) >> 1  # reciprocal of 2 (mod n)
    # Ladder state: (u_i, v_i, q_i) is congruent mod n to (U_k, V_k, Q^k), where k is the
    # prefix of n + 1's binary digits consumed so far. Start at k = 1 (U_1 = 1, V_1 = P = 1),
    # which accounts for the leading 1 bit.
    u_i, v_i, q_i = 1, 1, q
    bits = list(to_base(n + 1, 2))
    for bit_j in bits[1:]:
        # Double: U_2k = U_k V_k, V_2k = V_k^2 - 2 Q^k, Q^2k = (Q^k)^2
        u_i = u_i * v_i % n
        v_i = (v_i * v_i - (q_i << 1)) % n
        q_i = q_i * q_i % n
        # Increment if required: U_{k+1} = (U_k + V_k) / 2 and, since P = 1 and D = 1 - 4Q,
        # V_{k+1} = (D U_k + V_k) / 2 = U_{k+1} - 2 Q U_k. v_i must use the old u_i,
        # so it is updated before u_i is overwritten.
        if bit_j == 1:
            # Not necessary to reduce, because the results here are at most a couple of bits more rather than twice as many
            avg = (u_i + v_i) * half % n
            v_i = avg - (q << 1) * u_i
            u_i = avg
            q_i = q_i * q
    # n passes iff U_{n+1} = 0 (mod n).
    return u_i == 0
def jacobi(n: int, k: int) -> int:
    """
    Compute the Jacobi symbol (n / k) for an odd positive modulus k.

    Returns -1, 0 or 1; the result is 0 exactly when gcd(n, k) > 1.
    Raises ValueError if k is not a positive odd integer.
    """
    if k <= 0 or not k & 1:
        raise ValueError
    top, bottom = n % k, k
    sign = 1
    while top != 0:
        # Remove factors of two: (2 / m) = -1 exactly when m = 3 or 5 (mod 8).
        while top % 2 == 0:
            top //= 2
            if bottom % 8 in (3, 5):
                sign = -sign
        # Quadratic reciprocity: swap, flipping the sign when both are 3 (mod 4).
        top, bottom = bottom, top
        if top % 4 == 3 and bottom % 4 == 3:
            sign = -sign
        top %= bottom
    # A final modulus other than 1 means n and k shared a common factor.
    return sign if bottom == 1 else 0
def to_base(n: int, b: int) -> Iterable[int]:
    """
    Extract the base-b digits of n in big-endian order (most significant first).

    Zero has no digits, so to_base(0, b) yields nothing.

    Raises ValueError if n is negative or b < 2: with floor division such inputs
    never reach zero (e.g. divmod(-1, 2) == (-1, 1)), so the digit loop would
    never terminate.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    if b < 2:
        raise ValueError("base must be at least 2")
    digits = []
    while n:
        n, lsd = divmod(n, b)
        digits.append(lsd)
    # Digits were collected least-significant first.
    return reversed(digits)
def isqrt(n) -> int:
    """Return floor(sqrt(n)), delegating to the general integer k-th root with k = 2."""
    root = ikth_root(k=2, n=n)
    return root
def ikth_root(k: int, n: int) -> int:
    """
    Integer k-th root: return the largest s such that s ** k <= n.

    Requires k >= 1 and n >= 0; raises ValueError otherwise.

    Uses integer Newton-Raphson on f(x) = x^k - n, whose step is
        x' = x - f(x)/f'(x) = ((k-1)x + n / x^{k-1}) / k.
    Started from any s above the true root, the floored step strictly decreases
    s until it reaches floor(root), after which it stops decreasing. The loop
    therefore ends exactly on the answer, with no linear clean-up. Starting
    above the root also avoids the enormous overshoot a below-root start causes
    for large k (e.g. k = 1000, n = 10**100).
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    if n < 0:
        raise ValueError("n must be non-negative")
    if n < 2:
        return n
    # n < 2^L with L = n.bit_length(), so 2^ceil(L/k) exceeds the root (by at most a factor 2).
    s = 1 << -(-n.bit_length() // k)
    # A float estimate is usually much tighter, but rounding can leave it below the
    # root for large n, so only use it when it is provably above the root.
    try:
        guess = int(n ** (1.0 / k)) + 1
    except OverflowError:
        pass  # n too large to convert to float; keep the bit-length bound
    else:
        if guess < s and guess ** k > n:
            s = guess
    while True:
        next_s = ((k - 1) * s + n // s ** (k - 1)) // k
        if next_s >= s:
            return s
        s = next_s
# Demo: list every integer in [2, 100) that the BPSW test accepts.
print(list(filter(is_probable_prime, range(2, 100))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment