Skip to content

Instantly share code, notes, and snippets.

@JosephJoshua
Created April 6, 2026 01:55
Show Gist options
  • Select an option

  • Save JosephJoshua/8aeb2ac5032d457d7dd66352a4bfbe0c to your computer and use it in GitHub Desktop.

Select an option

Save JosephJoshua/8aeb2ac5032d457d7dd66352a4bfbe0c to your computer and use it in GitHub Desktop.
"""Convert HuggingFace weights.pt → _object.ckpt for eval.py
Usage:
# Download weights first:
# pip install huggingface_hub
# huggingface-cli download quentinll/lewm-pusht --local-dir hf_pusht
python convert_hf_to_object.py \
--weights hf_pusht/weights.pt \
--config hf_pusht/config.json \
--output $STABLEWM_HOME/pusht/lewm_object.ckpt
"""
import argparse
import json
from pathlib import Path
import torch
import stable_pretraining as spt
from jepa import JEPA
from module import ARPredictor, Embedder, MLP
def build_model(config: dict) -> JEPA:
    """Construct a JEPA model whose architecture is described by *config*.

    *config* is the parsed ``config.json``. Five sections are read from it:
    ``encoder``, ``predictor``, ``action_encoder``, ``projector`` and
    ``pred_proj``. Each sub-module is built from the listed keys of its
    section, and the five modules are then wired together into a ``JEPA``.
    Checkpoint weights are NOT loaded here; the caller does that afterwards.

    Raises:
        KeyError: if a required section or key is missing from *config*.
    """

    def pick(section: dict, *names: str) -> dict:
        # Copy the requested keys, in the given order, into a kwargs dict.
        # A missing key raises the same KeyError a direct lookup would.
        return {name: section[name] for name in names}

    enc_section = config["encoder"]
    encoder = spt.backbone.utils.vit_hf(
        enc_section["size"],  # model size is passed positionally
        **pick(enc_section, "patch_size", "image_size", "pretrained", "use_mask_token"),
    )

    predictor = ARPredictor(
        **pick(
            config["predictor"],
            "num_frames",
            "input_dim",
            "hidden_dim",
            "output_dim",
            "depth",
            "heads",
            "mlp_dim",
            "dim_head",
            "dropout",
            "emb_dropout",
        )
    )

    action_encoder = Embedder(**pick(config["action_encoder"], "input_dim", "emb_dim"))

    # Both projection heads are MLPs with the same config keys and
    # BatchNorm1d normalisation; only the config section differs.
    mlp_keys = ("input_dim", "output_dim", "hidden_dim")
    projector = MLP(**pick(config["projector"], *mlp_keys), norm_fn=torch.nn.BatchNorm1d)
    pred_proj = MLP(**pick(config["pred_proj"], *mlp_keys), norm_fn=torch.nn.BatchNorm1d)

    return JEPA(
        encoder=encoder,
        predictor=predictor,
        action_encoder=action_encoder,
        projector=projector,
        pred_proj=pred_proj,
    )
def _print_key_mismatch(header: str, keys: set, limit: int = 10) -> None:
    """Print *header*, then up to *limit* of *keys* in sorted order.

    If there are more than *limit* keys, a final line reports how many were
    omitted, so a truncated listing is not mistaken for the complete set.
    """
    print(header)
    ordered = sorted(keys)
    for key in ordered[:limit]:
        print(f" {key}")
    if len(ordered) > limit:
        print(f" ... and {len(ordered) - limit} more")


def main():
    """Convert a HuggingFace ``weights.pt`` into a pickled model-object checkpoint.

    Command-line arguments:
        --weights: state-dict file, read with ``torch.load(weights_only=True)``.
        --config:  JSON file describing the architecture (consumed by
                   ``build_model``).
        --output:  destination for the ``torch.save``-d model object; missing
                   parent directories are created.

    Any key mismatches between the model and the weights are printed first.
    ``load_state_dict(strict=True)`` then raises ``RuntimeError`` for such a
    mismatch, so no output file is written in that case.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", required=True, help="Path to weights.pt")
    parser.add_argument("--config", required=True, help="Path to config.json")
    parser.add_argument("--output", required=True, help="Output _object.ckpt path")
    args = parser.parse_args()

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)

    print("Building model from config...")
    model = build_model(config)

    print(f"Loading weights from {args.weights}...")
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is the safer way to read a downloaded file.
    state_dict = torch.load(args.weights, map_location="cpu", weights_only=True)

    # Report key mismatches up front in a readable form; the strict load
    # below would otherwise fail with a much noisier error.
    model_keys = set(model.state_dict())
    weight_keys = set(state_dict)
    missing = model_keys - weight_keys
    unexpected = weight_keys - model_keys
    if missing:
        _print_key_mismatch(
            f"WARNING: {len(missing)} keys in model but not in weights:", missing
        )
    if unexpected:
        _print_key_mismatch(
            f"WARNING: {len(unexpected)} keys in weights but not in model:", unexpected
        )

    model.load_state_dict(state_dict, strict=True)
    # A state dict holds buffers (e.g. BatchNorm running statistics) as well
    # as parameters, so count entries rather than calling them parameters.
    print(f"Loaded {len(state_dict)} state-dict entries (parameters and buffers).")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # This pickles the whole nn.Module object, not just its weights.
    # NOTE(review): reading it back needs torch.load(..., weights_only=False)
    # on PyTorch versions where weights_only defaults to True.
    torch.save(model, output_path)
    print(f"Saved object checkpoint to {output_path}")
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment