Created May 28, 2026 06:56
-
-
Save fzyzcjy/7e2fce6287badbb672ce93bec4fe9745 to your computer and use it in GitHub Desktop.
Mechanical refactor transform: extract release_req / retract_all as module-level free functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Reproducible transform for: extract release_req / retract_all as module-level free functions. | |
| Run from the sglang repo root: | |
| python3 /tmp/transform_release_req_free_func.py | |
| """ | |
| import sys | |
| from pathlib import Path | |
| sys.path.append(".claude/skills/mechanical-refactor-verify") | |
| from mechanical_refactor_verify_utils import ( # noqa: E402 | |
| git_add_and_commit, | |
| verify_mechanical_refactor, | |
| ) | |
# Commit pair handed to verify_mechanical_refactor: the commit the transform
# starts from, and (presumably) the commit whose content its result must match.
BASE_COMMIT = "a34e245c12"
TARGET_COMMIT = "cb5d1b88a9"

# File rewritten by `transform`, relative to the repository root.
SRC_REL = "python/sglang/srt/managers/schedule_batch.py"

# Exact text of the ScheduleBatch class header; FREE_FUNCTIONS is inserted
# immediately in front of it.
ANCHOR = (
    "@dataclasses.dataclass\n"
    "class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):"
)
# Module-level versions of ScheduleBatch.release_req / retract_all. transform()
# inserts this text verbatim immediately before ANCHOR. Instead of reading batch
# state from `self`, the functions take everything they need (request list or
# request, server args, token pools, tree cache, hisparse coordinator) as
# keyword-only arguments.
# The misspelled `remaing_req_count` mirrors the parameter name of the existing
# ScheduleBatch.release_req method (see OLD_METHOD_RELEASE_REQ); presumably kept
# on purpose so the refactor stays purely mechanical.
# NOTE(review): indentation and blank lines inside this literal are part of the
# inserted text; confirm they match schedule_batch.py at TARGET_COMMIT.
FREE_FUNCTIONS = '''def release_req(
    *,
    req: Req,
    remaing_req_count: int,
    server_args: ServerArgs,
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
    tree_cache: BasePrefixCache,
    hisparse_coordinator: Optional[HiSparseCoordinator],
) -> None:
    if hisparse_coordinator is not None and not req.finished():
        hisparse_coordinator.retract_req(req)
    if server_args.disaggregation_mode == "decode":
        req.offload_kv_cache(req_to_token_pool, token_to_kv_pool_allocator)
    # TODO (csy): for preempted requests, we may want to insert into the tree
    release_kv_cache(req, tree_cache, is_insert=False)
    # NOTE(lsyin): we should use the newly evictable memory instantly.
    num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
    evict_from_tree_cache(tree_cache, num_tokens)
    req.reset_for_retract()


def retract_all(
    *,
    reqs: List[Req],
    server_args: ServerArgs,
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
    tree_cache: BasePrefixCache,
    hisparse_coordinator: Optional[HiSparseCoordinator],
) -> List[Req]:
    retracted_reqs = reqs
    for idx in range(len(reqs)):
        release_req(
            req=reqs[idx],
            remaing_req_count=len(reqs) - idx,
            server_args=server_args,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            hisparse_coordinator=hisparse_coordinator,
        )
    return retracted_reqs


'''
# Current text of ScheduleBatch.retract_all. transform() uses str.count() on it
# and requires exactly one occurrence in SRC_REL (checked after FREE_FUNCTIONS
# has been inserted). The match is exact, so indentation and blank lines must
# agree with the file byte-for-byte.
# NOTE(review): confirm the whitespace in this literal against
# schedule_batch.py at BASE_COMMIT.
OLD_METHOD_RETRACT_ALL = '''    def retract_all(self, server_args: ServerArgs):
        retracted_reqs = self.reqs
        for idx in range(len(self.reqs)):
            self.release_req(idx, len(self.reqs) - idx, server_args)
        self.reqs = []
        return retracted_reqs'''
# Replacement method text. It keeps the signature, the `self.reqs = []` reset
# and the return value, but hands the per-request release loop to the
# module-level retract_all(), passing the batch's pools explicitly. Inside a
# method body the bare name `retract_all` resolves to the module-level
# function, not to this method, because Python does not search the class
# namespace from within methods.
NEW_METHOD_RETRACT_ALL = '''    def retract_all(self, server_args: ServerArgs):
        retracted_reqs = retract_all(
            reqs=self.reqs,
            server_args=server_args,
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            tree_cache=self.tree_cache,
            hisparse_coordinator=self.hisparse_coordinator,
        )
        self.reqs = []
        return retracted_reqs'''
# Current text of ScheduleBatch.release_req. transform() requires exactly one
# exact (str.count) occurrence of it in SRC_REL, so indentation and blank lines
# must agree with the file byte-for-byte.
# NOTE(review): confirm the whitespace in this literal against
# schedule_batch.py at BASE_COMMIT.
OLD_METHOD_RELEASE_REQ = '''    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
        req = self.reqs[idx]
        if self.hisparse_coordinator is not None and not req.finished():
            self.hisparse_coordinator.retract_req(req)
        if server_args.disaggregation_mode == "decode":
            req.offload_kv_cache(
                self.req_to_token_pool, self.token_to_kv_pool_allocator
            )
        # TODO (csy): for preempted requests, we may want to insert into the tree
        release_kv_cache(req, self.tree_cache, is_insert=False)
        # NOTE(lsyin): we should use the newly evictable memory instantly.
        num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
        evict_from_tree_cache(self.tree_cache, num_tokens)
        req.reset_for_retract()'''
# Replacement method text. The signature is unchanged, including the `idx`
# argument and the `remaing_req_count` spelling, so calls to the method keep
# working. The body resolves `self.reqs[idx]` and forwards it, together with
# the batch's pools, to the module-level release_req(). As with retract_all,
# the bare name `release_req` inside the method body refers to the
# module-level function.
NEW_METHOD_RELEASE_REQ = '''    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
        release_req(
            req=self.reqs[idx],
            remaing_req_count=remaing_req_count,
            server_args=server_args,
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            tree_cache=self.tree_cache,
            hisparse_coordinator=self.hisparse_coordinator,
        )'''
def _replace_exactly_once(content: str, old: str, new: str, what: str) -> str:
    """Return *content* with its single occurrence of *old* replaced by *new*.

    Raises AssertionError when *old* occurs zero times or more than once. The
    check is an explicit ``raise`` rather than an ``assert`` statement, so it
    still runs under ``python -O``. *what* names the snippet in the message.
    """
    occurrences = content.count(old)
    if occurrences != 1:
        raise AssertionError(
            f"{what} must appear exactly once (found {occurrences})"
        )
    return content.replace(old, new, 1)


def transform(dir_root: Path) -> None:
    """Rewrite SRC_REL under *dir_root* and commit the result.

    The edits are applied in order to the file text:
      1. insert FREE_FUNCTIONS immediately before the ScheduleBatch header;
      2. replace ScheduleBatch.retract_all with the delegating version;
      3. replace ScheduleBatch.release_req with the delegating version.

    Each step requires its search text to occur exactly once. If it does not,
    AssertionError is raised before anything is written, so drifted base code
    fails loudly instead of producing a partial rewrite.
    """
    src = dir_root / SRC_REL
    # Python source files are UTF-8 by default; don't depend on the locale.
    content = src.read_text(encoding="utf-8")

    content = _replace_exactly_once(
        content, ANCHOR, FREE_FUNCTIONS + ANCHOR, "ScheduleBatch anchor"
    )
    content = _replace_exactly_once(
        content,
        OLD_METHOD_RETRACT_ALL,
        NEW_METHOD_RETRACT_ALL,
        "old ScheduleBatch.retract_all body",
    )
    content = _replace_exactly_once(
        content,
        OLD_METHOD_RELEASE_REQ,
        NEW_METHOD_RELEASE_REQ,
        "old ScheduleBatch.release_req body",
    )

    src.write_text(content, encoding="utf-8")
    # Commit in the checkout that was just rewritten (dir_root is passed as cwd).
    git_add_and_commit(
        "Extract release_req and retract_all as module-level free functions",
        cwd=str(dir_root),
    )
if __name__ == "__main__":
    # Hand the transform and the commit pair to the shared verifier from
    # .claude/skills/mechanical-refactor-verify. Presumably it applies
    # `transform` on BASE_COMMIT and compares the outcome with TARGET_COMMIT.
    verify_mechanical_refactor(
        transform=transform,
        base_commit=BASE_COMMIT,
        target_commit=TARGET_COMMIT,
    )
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.