2D MuJoCo MJX demo with depth sensing and velocity control
import jax
import numpy as np
import matplotlib.pyplot as plt
import mujoco
from mujoco import mjx
import mediapy as media
from math import sin, cos
from tqdm import trange
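
# Build the MJCF model as a string: a planar "vehicle" (two slide joints plus a
# hinge for x, y and yaw), a fan of rangefinder sites for depth sensing, and a
# grid of sphere obstacles. Contact is disabled, so the obstacles are only
# "seen" by the rangefinders and never collide with the vehicle.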
xml = """
<mujoco model="simple_2d">
<compiler autolimits="true"/>
<option integrator="implicitfast"/>
<asset>
<material name="body_material" rgba="0.2 0.8 0.2 1"/>
<material name="obstacle_material" rgba="0.8 0.2 0.2 1"/>
<material name="floor_material" rgba="0.3 0.3 0.3 1"/>
</asset>
<option timestep="0.02">
<flag contact="disable" />
</option>
<default>
<joint damping="0.25" stiffness="0.0"/>
</default>
<worldbody>
"""
sensor_angle = 0.6
num_sensors = 128
xml += f"""
<site name="origin"/>
<light pos="0 0 3" dir="0 0 -1" diffuse="0.8 0.8 0.8"/>
<!-- stacked joint: hinge + slide -->
<body pos="0.0 0 0" name="vehicle">
<joint name="x_joint" type="slide" axis="1. 0. 0." range="-1 1"/>
<joint name="y_joint" type="slide" axis="0. 1. 0." range="-1 1"/>
<joint name="rot_joint" type="hinge" axis="0 0 1."/>
<site name="velocity_site" pos="0 0 0" size="0.01"/>
<frame pos="0 0.01 0" quat="-1 1 0 0">
"""
for i, theta in enumerate(np.linspace(start=-sensor_angle, stop=sensor_angle, num=num_sensors)):
    xml += f"""
                <site name="site_rangefinder{i}" quat="{cos(theta/2)} 0 {sin(theta/2)} 0" size="0.01" rgba="1 0 0 1"/>
    """

xml += f"""
            </frame>
            <geom type="box" pos="0 0 0" size=".0168 .01 .005" mass="0.1"/>
        </body>
"""
obstacles = []
for x in np.linspace(-1., 1., 5):
    for y in np.linspace(-1., 1., 5):
        if x == 0. and y == 0.:
            continue
        obstacles.append([x, y, 0.07])

for i, (x, y, radius) in enumerate(obstacles):
    xml += f"""
        <!-- Obstacle {i} -->
        <geom name="obstacle_{i}" type="sphere" pos="{x} {y} 0"
              size="{radius}" contype="1" conaffinity="1" material="obstacle_material"/>
    """
xml += """
</worldbody>
<sensor>
"""
for i in range(num_sensors):
xml += f"""
<rangefinder name="rangefinder{i}" site="site_rangefinder{i}"/>
"""
xml += """
<framequat name="vehicle_quat" objtype="site" objname="velocity_site"/>
</sensor>
<actuator>
<!-- Forward/backward velocity control in body frame -->
<velocity name="body_y" site='velocity_site' kv="1." gear="0 1 0 0 0 0" ctrlrange="-2 2"/>
<!-- Angular velocity control around Z axis in body frame -->
<velocity name="angular_velocity" joint="rot_joint" kv="1." ctrlrange="-1 1"/>
</actuator>
</mujoco>
"""
# Create MuJoCo model and data
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
# Transfer model and data to MJX
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
# JIT-compile step function
jit_step = jax.jit(mjx.step)
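# mjx.put_model / mjx.put_data copy the model and state to the accelerator as
# JAX pytrees; jax.jit compiles mjx.step on its first call, so the first
# simulated frame includes compilation time.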
# Simulation parameters
duration = 20.0 # seconds
framerate = 30 # fps
n_frames = int(duration * framerate)
dt = mj_model.opt.timestep
steps_per_frame = max(1, int(1.0 / (framerate * dt)))
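# With timestep = 0.02 s and framerate = 30 fps, 1 / (framerate * dt) ~= 1.67,
# so steps_per_frame = 1: one physics step per rendered frame.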
# Create visualization options
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_RANGEFINDER] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True
scene_option.frame = mujoco.mjtFrame.mjFRAME_SITE
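# mjVIS_RANGEFINDER draws the sensor rays and mjVIS_JOINT draws the joints in
# the rendered frames; mjFRAME_SITE additionally draws site frames, which helps
# verify the rangefinder fan geometry.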
# Renderer dimensions (offscreen framebuffer size)
width, height = 480, 480
# Prepare data recording
time_points = []
rangefinder_data = []
joint_angle_data = []
# Reset simulation
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)
# Render and simulate
frames = []
with mujoco.Renderer(mj_model, height, width) as renderer:
    # Position the camera for a better view of the scene
    cam = mujoco.MjvCamera()
    cam.azimuth = 90
    cam.elevation = -50
    cam.distance = 3.5
    cam.lookat = np.array([0, 0, 0.2])

    target_vel, target_rotation_vel = 0.4, 1.

    for frame_idx in trange(n_frames):
        ctrl = jax.numpy.array([target_vel, target_rotation_vel])
        mjx_data = mjx_data.replace(ctrl=ctrl)

        # Run multiple physics steps between rendered frames
        for _ in range(steps_per_frame):
            mjx_data = jit_step(mjx_model, mjx_data)

        # Copy the device state back to the CPU for sensor readout and rendering
        mj_data = mjx.get_data(mj_model, mjx_data)

        # Record data
        time_points.append(mj_data.time)
        rangefinder_data.append([mj_data.sensor(f'rangefinder{i}').data.item() for i in range(num_sensors)])
        joint_angle_data.append(mj_data.qpos[0])

        # Render the frame
        renderer.update_scene(mj_data, camera=cam, scene_option=scene_option)
        pixels = renderer.render()
        frames.append(pixels)
# Create video file
output_filename = "forestnav_v1.mp4"
media.write_video(output_filename, frames, fps=framerate)
print(f"Video saved to {output_filename}")
# Plot rangefinder readings and joint position
plt.figure(figsize=(12, 8))
#
# Plot rangefinder readings
plt.subplot(2, 1, 1)
rangefinder_data = np.array(rangefinder_data)
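# rangefinder_data now has shape (n_frames, num_sensors); plot each beam's
# measured range against time.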
for i in range(num_sensors):
    plt.plot(time_points, rangefinder_data[:, i], label=f'Rangefinder {i}', linewidth=2)
plt.xlabel('Time (s)')
plt.ylabel('Distance (m)')
plt.title('Rangefinder Readings')
plt.legend()
plt.grid(True)
#
# # Plot joint angle
# plt.subplot(2, 1, 2)
# plt.plot(time_points, joint_angle_data, label='Joint Angle', color='green', linewidth=2)
# plt.xlabel('Time (s)')
# plt.ylabel('Angle (rad)')
# plt.title('Joint Angle')
# plt.grid(True)
#
plt.tight_layout()
plt.savefig('rangefinder_data.png')
plt.show()
print("Simulation complete!")