Skip to content

Instantly share code, notes, and snippets.

@fzyzcjy
Created May 28, 2026 06:56
Show Gist options
  • Select an option

  • Save fzyzcjy/7e2fce6287badbb672ce93bec4fe9745 to your computer and use it in GitHub Desktop.

Select an option

Save fzyzcjy/7e2fce6287badbb672ce93bec4fe9745 to your computer and use it in GitHub Desktop.
Mechanical refactor transform: extract release_req / retract_all as module-level free functions
#!/usr/bin/env python3
"""Reproducible transform for: extract release_req / retract_all as module-level free functions.
Run from the sglang repo root:
python3 /tmp/transform_release_req_free_func.py
"""
import sys
from pathlib import Path
sys.path.append(".claude/skills/mechanical-refactor-verify")
from mechanical_refactor_verify_utils import ( # noqa: E402
git_add_and_commit,
verify_mechanical_refactor,
)
# Commit ids handed to verify_mechanical_refactor when this file runs as a script.
BASE_COMMIT, TARGET_COMMIT = "a34e245c12", "cb5d1b88a9"

# Repo-relative path of the single file that transform() reads and rewrites.
SRC_REL = "python/sglang/srt/managers/schedule_batch.py"

# Exact two-line header of the ScheduleBatch class. transform() requires it to
# occur exactly once and inserts the new free functions immediately above it.
ANCHOR = (
    "@dataclasses.dataclass\n"
    "class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):"
)
# Source text of the two module-level functions that transform() inserts into
# SRC_REL immediately before ANCHOR, i.e. directly above the ScheduleBatch class.
#
# - release_req: the body of the old ScheduleBatch.release_req
#   (OLD_METHOD_RELEASE_REQ). Every `self.<attr>` it read (hisparse_coordinator,
#   req_to_token_pool, token_to_kv_pool_allocator, tree_cache) becomes an
#   explicit keyword-only argument, and the `self.reqs[idx]` lookup becomes
#   the `req` argument.
# - retract_all: calls release_req for each element of `reqs`, passing
#   remaing_req_count = len(reqs) - idx. It returns the same list object it was
#   given and, unlike the old method, does not clear anything itself.
#
# The misspelled parameter `remaing_req_count` is copied verbatim from the
# existing method signature, presumably to keep the refactor purely mechanical.
#
# NOTE(review): in this copy the literal's lines carry no leading indentation
# (e.g. the function bodies start at column 0), which would not be valid
# Python once inserted. The text presumably has to reproduce TARGET_COMMIT
# byte-for-byte, so confirm it against the original script before running.
FREE_FUNCTIONS = '''def release_req(
*,
req: Req,
remaing_req_count: int,
server_args: ServerArgs,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
tree_cache: BasePrefixCache,
hisparse_coordinator: Optional[HiSparseCoordinator],
) -> None:
if hisparse_coordinator is not None and not req.finished():
hisparse_coordinator.retract_req(req)
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(req_to_token_pool, token_to_kv_pool_allocator)
# TODO (csy): for preempted requests, we may want to insert into the tree
release_kv_cache(req, tree_cache, is_insert=False)
# NOTE(lsyin): we should use the newly evictable memory instantly.
num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
evict_from_tree_cache(tree_cache, num_tokens)
req.reset_for_retract()
def retract_all(
*,
reqs: List[Req],
server_args: ServerArgs,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
tree_cache: BasePrefixCache,
hisparse_coordinator: Optional[HiSparseCoordinator],
) -> List[Req]:
retracted_reqs = reqs
for idx in range(len(reqs)):
release_req(
req=reqs[idx],
remaing_req_count=len(reqs) - idx,
server_args=server_args,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
tree_cache=tree_cache,
hisparse_coordinator=hisparse_coordinator,
)
return retracted_reqs
'''
# Exact current text of ScheduleBatch.retract_all, presumably as of
# BASE_COMMIT. transform() requires it to occur exactly once in SRC_REL.
# NOTE(review): this copy has a single space before `def` and unindented body
# lines, whereas a method inside the ScheduleBatch class would be indented.
# If whitespace was lost when this script was copied, transform()'s
# exactly-once count check would most likely fail. Confirm against the
# original script.
OLD_METHOD_RETRACT_ALL = ''' def retract_all(self, server_args: ServerArgs):
retracted_reqs = self.reqs
for idx in range(len(self.reqs)):
self.release_req(idx, len(self.reqs) - idx, server_args)
self.reqs = []
return retracted_reqs'''
# Replacement text. The method keeps its signature but delegates to the
# module-level retract_all, forwarding the batch attributes that function
# needs, and still clears self.reqs afterwards.
# Inside the method body, the bare name `retract_all` resolves to the
# module-level function rather than the method, because class attributes are
# not in scope for plain name lookup inside methods.
NEW_METHOD_RETRACT_ALL = ''' def retract_all(self, server_args: ServerArgs):
retracted_reqs = retract_all(
reqs=self.reqs,
server_args=server_args,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tree_cache=self.tree_cache,
hisparse_coordinator=self.hisparse_coordinator,
)
self.reqs = []
return retracted_reqs'''
# Exact current text of ScheduleBatch.release_req, presumably as of
# BASE_COMMIT. transform() requires it to occur exactly once in SRC_REL (after
# the free functions have already been inserted).
# NOTE(review): this copy has a single space before `def` and unindented body
# lines, whereas a method inside the ScheduleBatch class would be indented.
# If whitespace was lost when this script was copied, transform()'s
# exactly-once count check would most likely fail. Confirm against the
# original script.
OLD_METHOD_RELEASE_REQ = ''' def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx]
if self.hisparse_coordinator is not None and not req.finished():
self.hisparse_coordinator.retract_req(req)
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
self.req_to_token_pool, self.token_to_kv_pool_allocator
)
# TODO (csy): for preempted requests, we may want to insert into the tree
release_kv_cache(req, self.tree_cache, is_insert=False)
# NOTE(lsyin): we should use the newly evictable memory instantly.
num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
evict_from_tree_cache(self.tree_cache, num_tokens)
req.reset_for_retract()'''
# Replacement text. The method keeps its (idx, remaing_req_count, server_args)
# signature but delegates to the module-level release_req, passing
# self.reqs[idx] as `req` plus the batch attributes the old body read.
# The bare name `release_req` in the method body resolves to the module-level
# function, not the method itself.
NEW_METHOD_RELEASE_REQ = ''' def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
release_req(
req=self.reqs[idx],
remaing_req_count=remaing_req_count,
server_args=server_args,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tree_cache=self.tree_cache,
hisparse_coordinator=self.hisparse_coordinator,
)'''
def _replace_exactly_once(content: str, old: str, new: str, message: str) -> str:
    """Return ``content`` with its single occurrence of ``old`` replaced by ``new``.

    Raises:
        AssertionError: with ``message`` if ``old`` occurs zero times or more
            than once. It is raised explicitly rather than via ``assert`` so
            the guard still runs under ``python -O``. Otherwise a silent no-op
            ``str.replace`` could let an unmodified file be committed.
    """
    if content.count(old) != 1:
        raise AssertionError(message)
    return content.replace(old, new, 1)


def transform(dir_root: Path) -> None:
    """Apply the release_req / retract_all extraction to the checkout at ``dir_root``.

    Each step requires its search text to occur exactly once in SRC_REL:
      1. insert FREE_FUNCTIONS directly above the ScheduleBatch class (ANCHOR);
      2. replace the old ScheduleBatch.retract_all text with the delegating one;
      3. replace the old ScheduleBatch.release_req text with the delegating one.
    The rewritten file is then written back and committed via git_add_and_commit.

    Args:
        dir_root: Root of the repository checkout to modify.

    Raises:
        AssertionError: if any search text is missing or ambiguous. Nothing is
            written to disk in that case, because the file is only written
            after all three replacements succeed.
    """
    src = dir_root / SRC_REL
    # Explicit UTF-8 so the round-trip does not depend on the machine's locale.
    content = src.read_text(encoding="utf-8")
    content = _replace_exactly_once(
        content,
        ANCHOR,
        FREE_FUNCTIONS + ANCHOR,
        "ScheduleBatch anchor must appear exactly once",
    )
    content = _replace_exactly_once(
        content,
        OLD_METHOD_RETRACT_ALL,
        NEW_METHOD_RETRACT_ALL,
        "old ScheduleBatch.retract_all body must appear exactly once",
    )
    content = _replace_exactly_once(
        content,
        OLD_METHOD_RELEASE_REQ,
        NEW_METHOD_RELEASE_REQ,
        "old ScheduleBatch.release_req body must appear exactly once",
    )
    src.write_text(content, encoding="utf-8")
    git_add_and_commit(
        "Extract release_req and retract_all as module-level free functions",
        cwd=str(dir_root),
    )
if __name__ == "__main__":
    # Hand the transform and the two commit ids to the shared verifier from
    # mechanical_refactor_verify_utils; all checking logic lives there.
    verify_mechanical_refactor(
        transform=transform,
        base_commit=BASE_COMMIT,
        target_commit=TARGET_COMMIT,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment