Skip to content

Instantly share code, notes, and snippets.

@yaoyaoding
Last active April 27, 2026 19:32
Show Gist options
  • Select an option

  • Save yaoyaoding/01bd6ca5169c2d52261c5e7d6c31408c to your computer and use it in GitHub Desktop.

Select an option

Save yaoyaoding/01bd6ca5169c2d52261c5e7d6c31408c to your computer and use it in GitHub Desktop.
Minimal repro: tvm-ffi Object aliasing across multiprocessing fork.
"""Minimal repro: tvm-ffi Object aliasing across multiprocessing fork.
Why we want each tvm-ffi C-handle to map to a *unique* Python wrapper
-------------------------------------------------------------------
A single C++ Object can currently be reached from Python by many distinct
wrapper instances — every ``obj.field`` / ``arr[i]`` access mints a fresh
``CObject`` wrapper around the same chandle. Same handle, but different
``id(...)``.
That's harmless when staying in-process (handle equality still works), but
it costs us at every Python-level identity boundary:
* **Pickle / multiprocessing.** Stock pickle dedupes by ``id()``. Two
fresh wrappers for the same handle look like two different objects, get
pickled twice, and on the unpickle side become two distinct C++ objects
with different chandles — i.e. structural aliasing in the IR is lost.
See ``worker`` vs ``worker1`` below.
* **Identity-keyed Python dicts / sets** (rewriter memos, name tables).
``d[var]`` works (``__hash__`` / ``__eq__`` are chandle-based), but any
``a is b`` shortcut fails on fresh wrappers — every visit looks "new",
which churns memos and breaks no-op short-circuits in rewriters.
* **Debugging / readability.** ``id(obj)`` doesn't correlate with
identity; printing the same node twice shows two different ids; and a
``repr`` cache keyed on the wrapper isn't reused even for the same
underlying object.
If we kept exactly *one* Python wrapper per chandle, alive as long as
the C++ object lives:
* ``a is b`` ⇔ same chandle (cheap identity test).
* Pickle's id-dedupe alone preserves aliasing across processes (for objects shipped in the same pickled payload — pickle's memo is per dump).
* Memo/cache lookups by wrapper hit on every reuse without needing a
custom chandle-keyed map.
* ``id()`` becomes a stable per-object label.
This script demonstrates the failure mode and what sharing one wrapper per chandle would buy us:
* ``worker`` receives ``(layout.axes, layout.offset)`` — distinct
wrappers per access. Worker side: chandles diverge.
* ``worker1`` receives ``((i, j), (i, j))`` — the *same* Python wrappers
shared between the two tuple slots. Worker side: chandles preserved,
because pickle's id-based dedupe got a chance to fire.
"""
from __future__ import annotations
import multiprocessing
from typing import Tuple
import tvm_ffi
from tvm_ffi.dataclasses import py_class
# Common base class for this repro's expression nodes; Var and Add derive
# from it. It declares no fields of its own.
# NOTE(review): "repro.Expr" is presumably the type key py_class registers the
# class under, and structural_eq="tree" presumably selects tree-style
# structural equality -- confirm against the tvm_ffi.dataclasses docs.
@py_class("repro.Expr", frozen=True, structural_eq="tree")
class Expr(tvm_ffi.Object):
    """Base class for expression nodes in this repro."""
    pass
# Named variable leaf. The repro hinges on the *same* Var objects being
# referenced from both Layout.axes and Layout.offset (see main()).
# NOTE(review): structural_eq="var" differs from the "tree" used by the other
# classes; presumably it gives variables binding-style rather than by-value
# structural equality -- confirm in the tvm_ffi.dataclasses docs.
@py_class("repro.Var", frozen=True, structural_eq="var")
class Var(Expr):
    """Variable leaf of the expression tree."""
    # Human-readable name; first field of every (name, chandle, id) record
    # this script prints.
    name: str
# Binary addition node; collect_vars() walks ``a`` before ``b``.
@py_class("repro.Add", frozen=True, structural_eq="tree")
class Add(Expr):
    """Sum of two sub-expressions."""
    # Left operand.
    a: Expr
    # Right operand.
    b: Expr
# Toy layout: a tuple of axis Vars plus an offset expression. In main() the
# offset is built from the very same Var objects stored in ``axes``, so the
# two fields alias at the C++ level (matching chandles in the "parent" line).
@py_class("repro.Layout", frozen=True, structural_eq="tree")
class Layout(tvm_ffi.Object):
    """Axes plus an offset expression, used to probe aliasing across a fork."""
    # Axis variables, in order.
    axes: tuple[Var, ...]
    # Offset expression; in this repro it references the Vars in ``axes``.
    offset: Expr
def collect_vars(node) -> list:
    """Return the ``Var`` leaves of an expression tree, left to right.

    A ``Var`` contributes itself; an ``Add`` contributes the Vars of ``a``
    followed by those of ``b``; any other node contributes nothing. A Var that
    occurs several times is listed once per occurrence. The returned list
    holds the wrapper objects themselves, so they stay alive while the list
    does.
    """
    if isinstance(node, Var):
        return [node]
    if isinstance(node, Add):
        # Left subtree is fully visited before the right one is accessed.
        return collect_vars(node.a) + collect_vars(node.b)
    return []
def describe(label: str, layout: Layout) -> None:
    """Print whether ``layout.axes`` and the Vars inside ``layout.offset`` alias.

    Two notions of aliasing are checked, each as a subset test ("every axis
    record also appears among the offset's Vars"):

    * ``aligned_chandle`` compares C++ object handles (``__chandle__()``).
    * ``aligned_pyid`` compares Python wrapper identities (``id()``).

    Args:
        label: Tag printed at the start of the summary line (e.g. ``"parent"``).
        layout: Layout to inspect; its ``offset`` is walked with ``collect_vars``.
    """
    # Keep every wrapper alive until all ids have been read. Element/field
    # access can hand out a fresh wrapper each time; if each one were dropped
    # right after its id() was taken, CPython could hand the same address to a
    # later wrapper, so unrelated wrappers would report equal ids and
    # aligned_pyid could come out spuriously True.
    axis_wrappers = list(layout.axes)
    offset_wrappers = collect_vars(layout.offset)
    axes = [(v.name, hex(v.__chandle__()), id(v)) for v in axis_wrappers]
    offset_vars = [(v.name, hex(v.__chandle__()), id(v)) for v in offset_wrappers]
    axes_h = {h for _, h, _ in axes}
    offset_h = {h for _, h, _ in offset_vars}
    aligned_chandle = axes_h <= offset_h
    axes_pid = {p for _, _, p in axes}
    offset_pid = {p for _, _, p in offset_vars}
    aligned_pyid = axes_pid <= offset_pid
    print(f"[{label}] aligned_chandle={aligned_chandle} aligned_pyid={aligned_pyid}")
    print(f" axes (name, chandle, py id): {axes}")
    print(f" offset_vars (name, chandle, py id): {offset_vars}")
def worker(payload) -> Tuple:
    """Record chandles and wrapper ids for an ``(axes, offset)`` payload.

    Runs inside a pool worker.

    Args:
        payload: ``(axes_tuple, offset)`` where ``axes_tuple`` is an iterable
            of Vars and ``offset`` is an expression whose Vars are gathered
            with ``collect_vars``.

    Returns:
        ``("worker", axes, offset_vars)``. Each list holds
        ``(name, hex(chandle), id)`` tuples of plain ``str``/``int`` values.
    """
    axes_tuple, offset = payload
    # Hold every wrapper until all ids are taken; otherwise a dropped wrapper's
    # address can be reused and two different wrappers report the same id().
    axis_wrappers = list(axes_tuple)
    offset_wrappers = collect_vars(offset)
    axes = [(v.name, hex(v.__chandle__()), id(v)) for v in axis_wrappers]
    offset_vars = [(v.name, hex(v.__chandle__()), id(v)) for v in offset_wrappers]
    return ("worker", axes, offset_vars)
def worker1(payload) -> Tuple:
    """Record chandles and wrapper ids for two sequences of shared Vars.

    The payload is ``(axes, offset_vars)``, where ``offset_vars`` contains the
    **same Python wrapper objects** as ``axes`` (equal ``id()`` in the parent).
    This checks whether pickle's id-based dedupe alone is enough to keep
    chandles aligned on the worker side.

    Returns:
        ``("worker1", axes, offset_vars)`` with ``(name, hex(chandle), id)``
        tuples.
    """
    def record(var):
        return (var.name, hex(var.__chandle__()), id(var))

    axes_part, offset_part = payload
    return (
        "worker1",
        [record(var) for var in axes_part],
        [record(var) for var in offset_part],
    )
def report(label: str, axes, offset_vars) -> None:
    """Print a three-line alignment summary for one set of records.

    ``axes`` and ``offset_vars`` are sequences of ``(name, chandle, py_id)``
    triples.

    * ``aligned_chandle`` is True when every chandle in ``axes`` also occurs in
      ``offset_vars``.
    * ``aligned_pyid`` is the same subset test applied to the ids.

    An empty ``axes`` is therefore trivially aligned.
    """
    def chandles(rows):
        return {chandle for _name, chandle, _pid in rows}

    def py_ids(rows):
        return {pid for _name, _chandle, pid in rows}

    aligned_chandle = chandles(axes) <= chandles(offset_vars)
    aligned_pyid = py_ids(axes) <= py_ids(offset_vars)
    summary = (
        f"[{label}] aligned_chandle={aligned_chandle} aligned_pyid={aligned_pyid}",
        f" axes (name, chandle, py id): {axes}",
        f" offset_vars (name, chandle, py id): {offset_vars}",
    )
    for line in summary:
        print(line)
def main() -> None:
    """Inspect a layout locally, then again in two fork workers.

    A ``Layout`` is built whose ``offset`` (``i + j``) reuses the very Vars
    stored in ``axes``. The parent's view is printed first. The Vars are then
    shipped twice to a single-process fork pool:

    * ``worker`` receives ``(layout.axes, layout.offset)``. As the notes at the
      top of this script explain, each field access hands out a *fresh* Python
      wrapper, so pickle's id()-based memo cannot collapse the shared Vars and
      the worker's chandles diverge.
    * ``worker1`` receives ``((i, j), (i, j))``, i.e. the same two wrapper
      objects in both slots. Pickle therefore serializes each Var once, and
      the worker's chandles stay aligned.

    Uses the ``"fork"`` start method, which is only available on POSIX.
    """
    var_i = Var(name="i")
    var_j = Var(name="j")
    layout = Layout(axes=(var_i, var_j), offset=Add(a=var_i, b=var_j))
    describe("parent", layout)

    fork_ctx = multiprocessing.get_context("fork")

    def run_case(target, make_payload) -> None:
        # The payload is built only once the pool exists, i.e. the field
        # accesses happen at the same point as in a straight-line version.
        with fork_ctx.Pool(processes=1) as pool:
            outcomes = pool.map(target, [make_payload()])
        for tag, axes, offset_vars in outcomes:
            report(tag, axes, offset_vars)

    # Distinct wrappers per field access -> subtrees pickled independently.
    run_case(worker, lambda: (layout.axes, layout.offset))
    # The same wrapper objects in both slots -> pickle's memo sends each Var once.
    run_case(worker1, lambda: ((var_i, var_j), (var_i, var_j)))


if __name__ == "__main__":
    main()
```
python /home/yaoyaod/repos/tilus/.claude/worktrees/tvm-ffi-refactor/scripts/repro_var_aliasing.py
[parent] aligned_chandle=True aligned_pyid=False
axes (name, chandle, py id): [('i', '0x39ee26a0', 134977494077008), ('j', '0x39f257b0', 134977494076960)]
offset_vars (name, chandle, py id): [('i', '0x39ee26a0', 134977494076960), ('j', '0x39f257b0', 134977494076864)]
[worker] aligned_chandle=False aligned_pyid=False
axes (name, chandle, py id): [('i', '0x39f12d20', 134977494338192), ('j', '0x39f168d0', 134977494338288)]
offset_vars (name, chandle, py id): [('i', '0x3802eae0', 134977494338288), ('j', '0x3802eb10', 134977494338240)]
[worker1] aligned_chandle=True aligned_pyid=True
axes (name, chandle, py id): [('i', '0x39f01f20', 134977493769456), ('j', '0x3802dc40', 134977494338288)]
offset_vars (name, chandle, py id): [('i', '0x39f01f20', 134977493769456), ('j', '0x3802dc40', 134977494338288)]
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment