Skip to content

Instantly share code, notes, and snippets.

@araffin
Last active June 10, 2025 14:09
Show Gist options
  • Save araffin/1fb77a8f290ac248b2e76e01164f21e0 to your computer and use it in GitHub Desktop.
Save araffin/1fb77a8f290ac248b2e76e01164f21e0 to your computer and use it in GitHub Desktop.
Minimal implementation to solve the HalfCheetah env using open-loop oscillators
# MIT License Copyright (c) 2024 Antonin Raffin
import gymnasium as gym
import numpy as np
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv
# ---- Environment ----
# Create the HalfCheetah task with an on-screen viewer and wrap it so that
# per-episode statistics are reported in the step `info` dict.
env = gym.wrappers.RecordEpisodeStatistics(
    gym.make("HalfCheetah-v4", render_mode="human")
)
# Unwrapped environment object; the control loop reads the simulation state
# (`data.qpos` / `data.qvel`) and the control timestep (`dt`) from it.
mujoco_env = env.unwrapped
assert isinstance(mujoco_env, MujocoEnv)
# One oscillator per controlled joint.
n_joints = 6

# ---- PD controller gains ----
kp = 1.0  # proportional gain (applied to the position error)
kd = 0.05  # derivative gain (applied to the joint velocity)

# ---- First episode ----
# Reset with a fixed seed; `t` is the running time counter of the episode.
_ = env.reset(seed=0)
t = 0.0

# ---- Oscillator parameters (one entry per joint) ----
# Phase velocity used while sin(theta) <= 0 ...
omega_stance = np.full(n_joints, 2 * np.pi * 4.622)
# ... and while sin(theta) > 0.
omega_swing = np.full(n_joints, 2 * np.pi * 3.865)
# Starting phase of each oscillator, given as a fraction of a full turn.
phase_shifts = 2 * np.pi * np.array(
    [0.00, 0.789, 0.316, 0.294, 0.629, 0.921]
)
# Desired joint position is `amplitudes * sin(theta) + offsets`.
amplitudes = np.array(
    [1.123, -1.91, -1.204, 1.173, 1.196, -0.085]
)
offsets = np.array(
    [-0.114, 0.075, 0.002, -0.493, -0.501, -0.227]
)
# Integration step of the oscillator equations (1kHz).
oscillator_dt = 1e-3

# Initial oscillator phases (a copy, so `phase_shifts` is never mutated).
theta = np.copy(phase_shifts)
# Number of oscillator integration sub-steps per control step.
# round() is used instead of a bare int() because a float quotient can land
# just below a whole number (e.g. 0.3 / 0.1 == 2.9999999999999996), which
# int() would truncate, silently dropping one sub-step. The value does not
# change during the run, so it is computed once outside the loop.
n_substeps = int(round(mujoco_env.dt / oscillator_dt))

# Main control loop: runs until an exception (e.g. Ctrl+C) interrupts it.
# The try/finally guarantees the environment (and its viewer) is released.
try:
    while True:
        env.render()

        # Integrate oscillators equations
        for _ in range(n_substeps):
            # Each oscillator advances at the swing frequency while
            # sin(theta) > 0 and at the stance frequency otherwise.
            in_swing_phase = np.sin(theta) > 0
            theta_dot = np.where(in_swing_phase, omega_swing, omega_stance)
            # Integrate and keep theta in [0, 2 * pi]
            theta = (theta + oscillator_dt * theta_dot) % (2 * np.pi)

        # Open-Loop Control using oscillators
        desired_qpos = amplitudes * np.sin(theta) + offsets

        # Last n_joints entries of the simulator position/velocity vectors
        # (assumed to be the controlled joints -- TODO confirm against the
        # environment's state layout).
        joint_pos = mujoco_env.data.qpos[-n_joints:]
        joint_vel = mujoco_env.data.qvel[-n_joints:]

        # PD Control: desired qvel is zero
        desired_torques = kp * (desired_qpos - joint_pos) - kd * joint_vel
        # clip to action bounds
        desired_torques = np.clip(desired_torques, -1.0, 1.0)

        _, reward, terminated, truncated, info = env.step(desired_torques)
        # Running time counter for the current episode (only written here).
        t += mujoco_env.dt

        if terminated or truncated:
            print(f"Episode return: {float(info['episode']['r'].item()):.2f}")
            t, _ = 0.0, env.reset()
            # Reinitialize the oscillator phases for the new episode
            theta = phase_shifts.copy()
finally:
    env.close()
@araffin
Copy link
Author

araffin commented Mar 27, 2024

@ikamensh
Copy link

Hey, thanks for sharing! Under what license can we use this? I'd like to use it for some unit tests.

@araffin
Copy link
Author

araffin commented Jun 10, 2025

Hi,
please use the MIT license.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment