Created
April 6, 2026 01:55
-
-
Save JosephJoshua/8aeb2ac5032d457d7dd66352a4bfbe0c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
r"""Convert HuggingFace weights.pt → _object.ckpt for eval.py.

Builds the JEPA model described by config.json, loads the state dict from
weights.pt into it, and saves the whole model object with torch.save for use
by eval.py.

Usage:
    # Download weights first:
    # pip install huggingface_hub
    # huggingface-cli download quentinll/lewm-pusht --local-dir hf_pusht
    python convert_hf_to_object.py \
        --weights hf_pusht/weights.pt \
        --config hf_pusht/config.json \
        --output $STABLEWM_HOME/pusht/lewm_object.ckpt
"""
| import argparse | |
| import json | |
| from pathlib import Path | |
| import torch | |
| import stable_pretraining as spt | |
| from jepa import JEPA | |
| from module import ARPredictor, Embedder, MLP | |
def build_model(config: dict) -> JEPA:
    """Construct a JEPA whose architecture is described by ``config``.

    ``config`` must contain the sections ``encoder``, ``predictor``,
    ``action_encoder``, ``projector`` and ``pred_proj``. The named entries of
    each section are forwarded as keyword arguments to the matching
    constructor; the encoder's ``size`` is passed positionally. Both
    projection heads use ``torch.nn.BatchNorm1d`` as their ``norm_fn``.

    The returned model holds whatever weights the constructors produce; the
    caller is expected to load a state dict into it afterwards.

    Raises:
        KeyError: if a required section or entry is missing from ``config``.
    """

    def pick(section: dict, *names: str) -> dict:
        # Select the named entries in the given order, so a missing entry
        # raises KeyError for the first absent name.
        return {name: section[name] for name in names}

    enc_cfg = config["encoder"]
    encoder = spt.backbone.utils.vit_hf(
        enc_cfg["size"],
        **pick(enc_cfg, "patch_size", "image_size", "pretrained", "use_mask_token"),
    )

    predictor = ARPredictor(
        **pick(
            config["predictor"],
            "num_frames",
            "input_dim",
            "hidden_dim",
            "output_dim",
            "depth",
            "heads",
            "mlp_dim",
            "dim_head",
            "dropout",
            "emb_dropout",
        )
    )

    action_encoder = Embedder(**pick(config["action_encoder"], "input_dim", "emb_dim"))

    # The two projection heads share the same MLP layout and normalisation.
    projector = MLP(
        **pick(config["projector"], "input_dim", "output_dim", "hidden_dim"),
        norm_fn=torch.nn.BatchNorm1d,
    )
    pred_proj = MLP(
        **pick(config["pred_proj"], "input_dim", "output_dim", "hidden_dim"),
        norm_fn=torch.nn.BatchNorm1d,
    )

    return JEPA(
        encoder=encoder,
        predictor=predictor,
        action_encoder=action_encoder,
        projector=projector,
        pred_proj=pred_proj,
    )
def _print_key_diff(label: str, keys: set, limit: int = 10) -> None:
    """Print a WARNING line for ``keys`` followed by up to ``limit`` of them, sorted."""
    if not keys:
        return
    print(f"WARNING: {len(keys)} {label}:")
    ordered = sorted(keys)
    for k in ordered[:limit]:
        print(f" {k}")
    # Make truncation visible instead of silently dropping the remainder.
    if len(ordered) > limit:
        print(f" ... and {len(ordered) - limit} more")


def main():
    """Build a JEPA from --config, load --weights into it, and save it to --output.

    The whole model object (not just its state dict) is written with
    ``torch.save``, after switching it to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when the weights do not match
            the model (missing/unexpected keys or shape mismatches).
    """
    parser = argparse.ArgumentParser(
        description="Convert HuggingFace weights.pt to an _object.ckpt for eval.py"
    )
    parser.add_argument("--weights", required=True, help="Path to weights.pt")
    parser.add_argument("--config", required=True, help="Path to config.json")
    parser.add_argument("--output", required=True, help="Output _object.ckpt path")
    args = parser.parse_args()

    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)

    print("Building model from config...")
    model = build_model(config)

    print(f"Loading weights from {args.weights}...")
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which matters for a file downloaded from the Hub.
    state_dict = torch.load(args.weights, map_location="cpu", weights_only=True)

    # Report key mismatches readably before the strict load below raises
    # (the strict load also rejects shape mismatches, which are not listed here).
    model_keys = set(model.state_dict())
    weight_keys = set(state_dict)
    _print_key_diff("keys in model but not in weights", model_keys - weight_keys)
    _print_key_diff("keys in weights but not in model", weight_keys - model_keys)

    model.load_state_dict(state_dict, strict=True)
    # A state dict holds buffers (e.g. BatchNorm running stats) as well as
    # learnable parameters, so this is a count of tensors.
    print(f"Loaded {len(state_dict)} tensors successfully.")

    # The checkpoint is meant for evaluation: disable training-mode behaviour
    # (BatchNorm batch statistics, dropout) so the saved object is ready for
    # inference as loaded.
    model.eval()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Pickles the whole module object; loading it requires the model classes
    # to be importable in the consuming process.
    torch.save(model, output_path)
    print(f"Saved object checkpoint to {output_path}")
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment