Created
October 25, 2024 12:57
-
-
Save maurges/090db3c6ec66795b396370ffe79bc4fa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env runhaskell | |
{-# LANGUAGE LambdaCase #-} | |
import GHC.Stack (HasCallStack) | |
import Data.List (partition) | |
import Data.Foldable (foldl') | |
import Numeric (showHex) | |
import System.Environment (getArgs) | |
import Text.Read (readMaybe) | |
-- | Entry point: pick the modulus and the assembly mode from the leading
-- command-line flags and pass the remaining arguments to the matching
-- assembler. With no arguments, or with a lone @--help@, print the usage line.
main :: IO ()
main = do
  argv <- getArgs
  case argv of
    [] -> printUsage
    ["--help"] -> printUsage
    "--stark" : "--interpolate" : ixs : presigs ->
      assembleInterpolate m_stark ixs presigs
    "--interpolate" : ixs : presigs ->
      assembleInterpolate m_k256 ixs presigs
    "--stark" : presigs ->
      assembleAdditive m_stark presigs
    presigs ->
      assembleAdditive m_k256 presigs
  where
    printUsage =
      putStrLn "Usage: ./assemble-presigs.hs [--stark] [--interpolate I1,I2...] PRESIG1 PRESIG2 ..."
-- | Combine additively shared presignatures into one signature and print it.
--
-- Every argument is a hex string whose first 64 characters are @r@ (which
-- must be identical across all arguments) and whose remainder is one share
-- of @s@. The shares are summed modulo @m@, the smaller of @s@ and @m - s@ is
-- kept, and @r@ followed by @s@ as 64 zero-padded hex digits is written to
-- stdout.
--
-- Fails with 'error' when no presignature is given or when the @r@ parts
-- differ, and through 'unwrap' when a share is not valid hexadecimal.
assembleAdditive :: Integer -> [String] -> IO ()
assembleAdditive m args = do
  let (r_hexs, s_hexs) = unzip . map (splitAt 64) $ args
  -- Pattern match instead of 'head' so an empty argument list is reported
  -- with a meaningful message.
  r_hex <- case r_hexs of
    [] -> error "No presignatures given"
    (r : _) -> pure r
  if all (== r_hex) r_hexs
    then pure ()
    else error "Mismatching rs"
  ss <- unwrap $ traverse (readMaybe . (<>) "0x") s_hexs :: IO [Integer]
  let s' = sumMod m ss
  -- Keep whichever of s and m - s is smaller.
  let s = min s' (m - s')
  -- Pad s to 64 hex digits, as 'assembleInterpolate' does, so the output has
  -- a fixed width even when s has leading zero digits.
  putStrLn $ r_hex <> leftpad 64 '0' (hex s)
-- | Combine polynomially shared presignatures into one signature and print it.
--
-- @preimages@ is a comma-separated list of integers: the evaluation point of
-- each party's share, in the same order as @args@. Every element of @args@ is
-- a hex string whose first 64 characters are @r@ (which must be identical
-- across all arguments) and whose remainder is that party's share of @s@.
-- The shares are interpolated at zero modulo @m@, the smaller of @s@ and
-- @m - s@ is kept, and @r@ followed by @s@ as 64 zero-padded hex digits is
-- written to stdout.
--
-- Fails with 'error' when no presignature is given, when the @r@ parts
-- differ, when the number of preimages differs from the number of
-- presignatures, or when two preimages coincide modulo @m@; fails through
-- 'unwrap' when a preimage or a share cannot be parsed.
assembleInterpolate :: Integer -> String -> [String] -> IO ()
assembleInterpolate m preimages args = do
  rawPreimages <- unwrap . traverse readMaybe . splitBy ',' $ preimages :: IO [Integer]
  -- 'modMinus' expects operands already reduced modulo m.
  let sharePreimages = map (`mod` m) rawPreimages
  let (r_hexs, s_hexs) = unzip . map (splitAt 64) $ args
  -- Pattern match instead of 'head' so an empty argument list is reported
  -- with a meaningful message.
  r_hex <- case r_hexs of
    [] -> error "No presignatures given"
    (r : _) -> pure r
  if all (== r_hex) r_hexs
    then pure ()
    else error "Mismatching rs"
  -- A length mismatch would otherwise be truncated silently by the zip inside
  -- 'interpolateAtZero', producing a wrong signature.
  if length sharePreimages == length args
    then pure ()
    else error "Number of share preimages does not match number of presignatures"
  -- Equal points make a Lagrange denominator zero, which has no inverse.
  if distinct sharePreimages
    then pure ()
    else error "Duplicate share preimages"
  ss <- unwrap $ traverse (readMaybe . (<>) "0x") s_hexs :: IO [Integer]
  -- Evaluate at zero the polynomial that takes the values ss at the points
  -- sharePreimages; the number of points determines its degree.
  let s' = interpolateAtZero m sharePreimages ss
  -- Keep whichever of s and m - s is smaller.
  let s = min s' (m - s')
  putStrLn $ r_hex <> leftpad 64 '0' (hex s)
  where
    -- True when no element occurs twice.
    distinct [] = True
    distinct (y : ys) = y `notElem` ys && distinct ys
-- | Render an integer as hexadecimal digits via 'showHex', with no prefix
-- and no padding.
hex :: Integer -> String
hex n = showHex n ""
-- | Prepend copies of @c@ to @l@ until it is @n@ elements long. A list that
-- is already at least @n@ long is returned unchanged.
leftpad :: Int -> a -> [a] -> [a]
leftpad n c l = padding ++ l
  where
    padding = replicate (n - length l) c
-- | Order of the secp256k1 group; the default modulus when @--stark@ is not
-- given.
m_k256 :: Integer
m_k256 = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
-- | Order of the Stark curve group; the modulus selected by @--stark@.
m_stark :: Integer
m_stark = 0x0800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f
-- | Sum of the values, reduced modulo @m@ after every addition. An empty
-- collection sums to 0.
sumMod :: (Foldable t, Integral a) => a -> t a -> a
sumMod m values = foldl' step 0 values
  where
    step acc v = (acc + v) `mod` m
-- | Product of the values, reduced modulo @m@ after every multiplication.
-- An empty collection yields 1.
prodMod :: (Foldable t, Integral a) => a -> t a -> a
prodMod m values = foldl' step 1 values
  where
    step acc v = (acc * v) `mod` m
-- | Lagrange-interpolate at zero, modulo @m@: each sharing is multiplied by
-- the Lagrange coefficient of its preimage and the products are summed.
-- Sharings and preimages are paired positionally; surplus elements of the
-- longer list are ignored.
interpolateAtZero
  :: Integer -- ^ Modulus
  -> [Integer] -- ^ Share preimages
  -> [Integer] -- ^ Sharings
  -> Integer
interpolateAtZero m preims sharings = foldl' addTerm 0 weighted
  where
    -- One index per preimage: 0, 1, 2, ...
    positions = [j | (j, _) <- zip [0 ..] preims]
    lambdas = map (\j -> lagrangeCoef m 0 j preims) positions
    weighted = zip sharings lambdas
    addTerm acc (sharing, lambda) = (acc + lambda * sharing) `mod` m
-- | Lagrange basis coefficient, modulo @m@, of the @j@-th point of @xs@
-- evaluated at @x@: the product of @(x - xm)@ over every other point @xm@,
-- divided by the product of @(xj - xm)@ over the same points.
lagrangeCoef
  :: Integer -- ^ Modulus
  -> Integer -- ^ Point of interpolation
  -> Int -- ^ Index of @xs@ which gives the point we're calculating the coefficient at
  -> [Integer] -- ^ Points of known polynomial values
  -> Integer
lagrangeCoef m x j xs = (numerator * invert m denominator) `mod` m
  where
    (xj, others) = pickAt j xs
    numerator = prodMod m [modMinus m x xm | xm <- others]
    denominator = prodMod m [modMinus m xj xm | xm <- others]

    -- Remove the element at the given position, returning it together with
    -- the remaining elements in their original order.
    pickAt :: Int -> [a] -> (a, [a])
    pickAt index vals
      | index < 0 = error "Index out of bounds"
      | otherwise = case splitAt index vals of
          (before, val : after) -> (val, before ++ after)
          (_, []) -> error "Index out of bounds"
-- | Return the value inside a 'Just'; on 'Nothing' abort with 'error'.
-- The 'HasCallStack' constraint lets the failure report where 'unwrap' was
-- called from.
unwrap :: HasCallStack => Maybe a -> IO a
unwrap = maybe (error "Unwrap: Nothing") pure
-- | Split a list on every occurrence of a separator element. Empty input
-- gives no pieces, and a single trailing separator does not produce a final
-- empty piece.
splitBy :: Eq a => a -> [a] -> [[a]]
splitBy sep = go
  where
    go [] = []
    go remaining = case break (== sep) remaining of
      -- 'drop 1' skips the separator itself, if one was found.
      (piece, rest) -> piece : go (drop 1 rest)
-- | Subtract @b@ from @a@ modulo @m@. Assumes both operands are already
-- reduced modulo @m@: a negative difference is wrapped by adding @m@ once,
-- and a non-negative one is returned as is.
modMinus :: Integer -> Integer -> Integer -> Integer
modMinus m a b =
  let diff = a - b
   in if diff < 0 then m + diff else diff
-- | Multiplicative inverse of @val@ modulo @m@, computed with the extended
-- Euclidean algorithm and shifted into the non-negative range.
-- No check is made that an inverse exists; the result is only meaningful
-- when @val@ and @m@ are coprime.
invert :: Integer -> Integer -> Integer
invert m val
  | coef < 0 = m + coef
  | otherwise = coef
  where
    -- Coefficient of val in the Bezout identity for val and m.
    coef = bezout val m 1 0

    -- Arguments: previous and current remainder, then the coefficients of
    -- val that belong to them.
    bezout _ 0 sPrev _ = sPrev
    bezout rPrev rCur sPrev sCur =
      let (q, rNext) = rPrev `quotRem` rCur
       in bezout rCur rNext sCur (sPrev - q * sCur)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment