Skip to content

Instantly share code, notes, and snippets.

@ClementWalter
Last active December 28, 2024 17:03
Show Gist options
  • Save ClementWalter/5b75f2c8522beae3ae9a9f90d31edfcf to your computer and use it in GitHub Desktop.
Save ClementWalter/5b75f2c8522beae3ae9a9f90d31edfcf to your computer and use it in GitHub Desktop.
CASM constraints with polars rs
# %% Imports
from dataclasses import asdict
import polars as pl
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager
from starkware.cairo.lang.compiler.preprocessor.flow import ReferenceManager
from starkware.cairo.lang.compiler.program import Program
from starkware.cairo.lang.compiler.scoped_name import ScopedName
from starkware.cairo.lang.vm.cairo_runner import CairoRunner
# %% Dummy runner for linting
# Build an empty program (no bytecode, builtins, hints or identifiers) so that
# `runner` has a concrete CairoRunner type for the cells below. The cells read
# `runner.relocated_memory` / `runner.relocated_trace`; presumably a runner that
# has actually executed and relocated a program must be substituted to run them.
_empty_program = Program(
    prime=DEFAULT_PRIME,
    data=[],
    hints={},
    builtins=[],
    identifiers=IdentifierManager(),
    reference_manager=ReferenceManager(),
    main_scope=ScopedName(path=("__main__", "main")),
    compiler_version=None,
)
runner = CairoRunner(program=_empty_program)
# %% Get memory and trace from runner
# `memory`: one row per relocated address, with the stored value rendered as
# lowercase hex zero-padded to at least 16 digits (the instruction decoder
# below slices this string).
# `trace`: one row per relocated trace entry, one column per dataclass field.
_relocated_cells = runner.relocated_memory.data
memory = pl.DataFrame(
    {
        "address": _relocated_cells.keys(),
        "value": [format(word, "016x") for word in _relocated_cells.values()],
    }
)
trace = pl.DataFrame(list(map(asdict, runner.relocated_trace)))
# %% Instructions table
# Decode the memory word at every traced pc (and at pc + 1) into Cairo
# instruction fields. The word is the zero-padded hex string from `memory`,
# most significant digit first, so the fixed slices pick out, from the top:
#   [0:4]   16-bit flags field
#   [4:8]   off_op1
#   [8:12]  off_op0
#   [12:16] off_dst
# Offsets are kept as raw 16-bit values. Addresses missing from memory (left
# join -> null) decode to all zeros.
# NOTE(review): pc + 1 may hold an immediate rather than an instruction; an
# immediate wider than an Int64 (e.g. a negative felt) presumably makes
# `str.to_integer` raise here — confirm those rows are meant to be included.
_flags_field = pl.col("value").str.slice(0, 4).str.to_integer(base=16).fill_null(0)
instruction_table = (
    pl.concat([trace["pc"], trace["pc"] + 1])
    .unique()
    .to_frame()
    .join(memory, left_on="pc", right_on="address", how="left")
    .with_columns(
        instruction=pl.col("value").str.to_integer(base=16).fill_null(0),
        off_dst=pl.col("value").str.slice(12, 4).str.to_integer(base=16).fill_null(0),
        off_op0=pl.col("value").str.slice(8, 4).str.to_integer(base=16).fill_null(0),
        off_op1=pl.col("value").str.slice(4, 4).str.to_integer(base=16).fill_null(0),
        # List element i is flags_field >> (15 - i), i.e. the value of the top
        # i + 1 bits of the 16-bit field. Floor division by a power of two on
        # these non-negative ints is that shift, and it stays a native
        # expression. This replaces per-row Python `map_elements` callbacks,
        # which were slow and had no declared return dtype.
        flags=pl.concat_list([_flags_field // 2**k for k in range(15, -1, -1)]),
    )
    # Name element i "f_{15 - i}" so that f_k == flags_field >> k
    # (f_0 is the whole flags field, f_15 its top bit).
    .with_columns(flags=pl.col("flags").list.to_struct(fields=lambda i: f"f_{15 - i}"))
    .drop("value")
)
# %% Decoding assertion
# Every decoded word must re-pack into the original instruction: off_dst in
# bits 0-15, off_op0 in bits 16-31, off_op1 in bits 32-47 and the full flags
# field (f_0) in bits 48-63.
constraint = (
    pl.col("off_dst")
    + pl.col("off_op0") * 2**16
    + pl.col("off_op1") * 2**32
    + pl.col("f_0") * 2**48
).eq(pl.col("instruction"))
assert (
    instruction_table.unnest("flags")
    .select(valid=constraint)
    .get_column("valid")
    .all()
)
# %% Flags assertion
# With f_k == flags_field >> k, the k-th flag bit is f_k - 2 * f_{k+1}.
# The Cairo AIR requires each of the 15 bits (k = 0..14) to be boolean,
# i.e. bit * (bit - 1) == 0, and the unused top prefix f_15 to be zero.
# The check is evaluated row-wise on the struct fields. Transposing the table
# would create one column per instruction, which does not scale to real traces.
flag_bits = [pl.col(f"f_{k}") - 2 * pl.col(f"f_{k + 1}") for k in range(15)]
constraint = pl.all_horizontal(
    *(bit * (bit - 1) == 0 for bit in flag_bits),
    pl.col("f_15") == 0,
)
assert instruction_table.unnest("flags").select(valid=constraint)["valid"].all()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment