Skip to content

Instantly share code, notes, and snippets.

@vonHartz
Created December 9, 2024 14:00
Show Gist options
  • Save vonHartz/a5a910694ca1e4bd5fa3856f785753be to your computer and use it in GitHub Desktop.
Minimal example for curobo trajectory cost
from functools import partialmethod
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import torch
import torch.autograd.profiler as profiler
from curobo.cuda_robot_model.cuda_robot_model import (
CudaRobotModel,
CudaRobotModelConfig,
)
from curobo.curobolib.geom import geom_cu
from curobo.rollout.arm_reacher import (
ArmReacher,
_compute_g_dist_jit,
cat_sum_horizon_reacher,
cat_sum_reacher,
)
from curobo.rollout.dynamics_model.kinematic_model import KinematicModelState
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose
from curobo.types.robot import JointState, RobotConfig
from curobo.util_file import get_robot_path, join_path, load_yaml
from curobo.wrap.reacher.motion_gen import (
MotionGen,
MotionGenConfig,
MotionGenPlanConfig,
)
import pickle
# Matplotlib "tab:" palette colours for plotting: one per position component
# (presumably x, y, z) and one per quaternion component respectively.
dim_colors = ("tab:red", "tab:green", "tab:blue")
quat_colors = ("tab:red", "tab:green", "tab:blue", "tab:brown")
def custom_cost_fn(
    self: ArmReacher,
    state: KinematicModelState,
    action_batch=None,
    ref_traj_mu: torch.Tensor | None = None,
    weights: torch.Tensor | None = None,
    ref_traj_weight: float = 1000.0,
):
    """Compute the rollout cost, adding a reference-trajectory tracking term.

    The body mirrors ``ArmReacher.cost_fn``. It collects the base costs, then the
    goal pose cost (:class:`curobo.rollout.cost.PoseCost`), the per-link pose
    costs, the c-space distance cost (:class:`curobo.rollout.cost.DistCost`), the
    straight-line cost and the zero acc/jerk/vel costs. It then appends a term
    penalising the distance between end-effector positions and ``ref_traj_mu``.
    It is intended to be bound onto ``ArmReacher`` with
    ``functools.partialmethod``, so ``self`` is the rollout instance.

    Args:
        self: The ``ArmReacher`` rollout whose cost terms are evaluated.
        state: Rollout state; ``state_seq``, ``ee_pos_seq``, ``ee_quat_seq`` and
            ``link_pose`` are read.
        action_batch: Forwarded unchanged to the base-class ``cost_fn``.
        ref_traj_mu: Reference trajectory; only columns ``[:, :3]`` are used and
            compared with the end-effector positions. When ``None`` the
            tracking term is skipped.
        weights: Currently unused; kept so existing ``partialmethod`` bindings
            that pass it keep working.
        ref_traj_weight: Scale applied to the tracking distance.

    Returns:
        The combined cost. It is reduced with ``cat_sum_horizon_reacher`` when
        ``self.sum_horizon`` is set, otherwise with ``cat_sum_reacher``.
    """
    state_batch = state.state_seq
    with profiler.record_function("cost/base"):
        # Two-argument super is required because this function lives outside
        # the class: it skips ArmReacher.cost_fn (which this replaces) and
        # collects the parent's cost terms as a list.
        cost_list = super(ArmReacher, self).cost_fn(
            state, action_batch, return_list=True
        )
    ee_pos_batch, ee_quat_batch = state.ee_pos_seq, state.ee_quat_seq

    # Only set when self._compute_g_dist is enabled; consumed by the
    # zero acc/jerk/vel costs below.
    g_dist = None
    with profiler.record_function("cost/pose"):
        if (
            self._goal_buffer.goal_pose.position is not None
            and self.cost_cfg.pose_cfg is not None
            and self.goal_cost.enabled
        ):
            if self._compute_g_dist:
                goal_cost, rot_err_norm, goal_dist = (
                    self.goal_cost.forward_out_distance(
                        ee_pos_batch,
                        ee_quat_batch,
                        self._goal_buffer,
                    )
                )
                g_dist = _compute_g_dist_jit(rot_err_norm, goal_dist)
            else:
                goal_cost = self.goal_cost.forward(
                    ee_pos_batch, ee_quat_batch, self._goal_buffer
                )
            cost_list.append(goal_cost)

    with profiler.record_function("cost/link_poses"):
        if (
            self._goal_buffer.links_goal_pose is not None
            and self.cost_cfg.pose_cfg is not None
        ):
            link_poses = state.link_pose
            for k in self._goal_buffer.links_goal_pose.keys():
                # The end-effector link is handled by the goal cost above.
                if k == self.kinematics.ee_link:
                    continue
                current_fn = self._link_pose_costs[k]
                if not current_fn.enabled:
                    continue
                current_pose = link_poses[k].contiguous()
                c = current_fn.forward(
                    current_pose.position,
                    current_pose.quaternion,
                    self._goal_buffer,
                    k,
                )
                cost_list.append(c)

    if (
        self._goal_buffer.goal_state is not None
        and self.cost_cfg.cspace_cfg is not None
        and self.dist_cost.enabled
    ):
        joint_cost = self.dist_cost.forward_target_idx(
            self._goal_buffer.goal_state.position,
            state_batch.position,
            self._goal_buffer.batch_goal_state_idx,
        )
        cost_list.append(joint_cost)

    if self.cost_cfg.straight_line_cfg is not None and self.straight_line_cost.enabled:
        st_cost = self.straight_line_cost.forward(ee_pos_batch)
        cost_list.append(st_cost)

    if self.cost_cfg.zero_acc_cfg is not None and self.zero_acc_cost.enabled:
        z_acc = self.zero_acc_cost.forward(
            state_batch.acceleration,
            g_dist,
        )
        cost_list.append(z_acc)

    if self.cost_cfg.zero_jerk_cfg is not None and self.zero_jerk_cost.enabled:
        z_jerk = self.zero_jerk_cost.forward(
            state_batch.jerk,
            g_dist,
        )
        cost_list.append(z_jerk)

    if self.cost_cfg.zero_vel_cfg is not None and self.zero_vel_cost.enabled:
        z_vel = self.zero_vel_cost.forward(
            state_batch.velocity,
            g_dist,
        )
        cost_list.append(z_vel)

    # Reference-tracking term: weighted Euclidean distance between the
    # end-effector positions and the reference positions.
    # NOTE(review): assumes ref_traj_mu has one row per rollout timestep so the
    # subtraction broadcasts against ee_pos_batch; confirm for every rollout
    # (e.g. IK) that this cost_fn gets bound to.
    if ref_traj_mu is not None:
        tracking_cost = ref_traj_weight * torch.linalg.norm(
            ref_traj_mu[:, :3] - ee_pos_batch, dim=-1
        )
        cost_list.append(tracking_cost)

    with profiler.record_function("cat_sum"):
        if self.sum_horizon:
            cost = cat_sum_horizon_reacher(cost_list)
        else:
            cost = cat_sum_reacher(cost_list)
    return cost
class CuroboModel:
    """Build a cuRobo robot model from a YAML config and plan trajectories with it.

    The constructor reads ``robot_cfg.kinematics.{urdf_path, base_link,
    ee_link}`` from the YAML file and builds a ``RobotConfig`` and a
    ``CudaRobotModel`` from them. ``optimize_trajectory`` runs cuRobo's
    ``MotionGen`` with ``custom_cost_fn`` patched into ``ArmReacher``.
    """

    # Names of the 7 arm joints that make up the planning start state.
    _ARM_JOINT_NAMES = (
        "panda_joint1",
        "panda_joint2",
        "panda_joint3",
        "panda_joint4",
        "panda_joint5",
        "panda_joint6",
        "panda_joint7",
    )

    def __init__(
        self,
        config_path: str = (
            "<INSERT_PARENT_DIR_OF_rlbench_panda.yml_HERE>/rlbench_panda.yml"
        ),
    ) -> None:
        """Load the robot YAML and build the kinematic model.

        Args:
            config_path: Path to the robot YAML file. The default is a
                placeholder that must be replaced before running.
        """
        self.tensor_args = TensorDeviceType()
        self.config_file = load_yaml(config_path)
        kinematics_cfg = self.config_file["robot_cfg"]["kinematics"]
        self.urdf_file = kinematics_cfg["urdf_path"]
        self.base_link = kinematics_cfg["base_link"]
        self.ee_link = kinematics_cfg["ee_link"]
        # NOTE(review): only urdf_path/base_link/ee_link are passed on here; the
        # other YAML entries (collision spheres, lock_joints, cspace, ...) are
        # not forwarded to RobotConfig.from_basic. Confirm this is intended.
        self.robot_cfg = RobotConfig.from_basic(
            self.urdf_file, self.base_link, self.ee_link, self.tensor_args
        )
        self.model = CudaRobotModel(self.robot_cfg.kinematics)

    def optimize_trajectory(
        self,
        ee_poses: torch.Tensor,
        reference_qpos: torch.Tensor,
        weights: torch.Tensor | None = None,
        world_config: dict | None = None,
        dt: float = 0.05,
        trajectory_cost_cls: type | None = None,
    ):
        """Plan a trajectory that tracks ``ee_poses`` and ends at its last pose.

        While the planner is built and run, ``ArmReacher.cost_fn`` is replaced by
        ``custom_cost_fn`` with ``ee_poses`` bound as the reference trajectory.
        The original class attributes are restored afterwards, including when
        planning fails.

        Args:
            ee_poses: Reference end-effector poses, one row per timestep. The
                row count sets ``trajopt_tsteps`` and ``interpolation_steps``,
                and the last row is the planning goal. Assumed to be in the
                layout ``Pose.from_list`` expects (``[x, y, z, qw, qx, qy, qz]``);
                TODO confirm.
            reference_qpos: Start configuration; its first 7 entries are used as
                the arm joint positions.
            weights: Forwarded to ``custom_cost_fn``.
            world_config: cuRobo world configuration; ``None`` means empty.
            dt: Passed to MotionGen as ``interpolation_dt``.
            trajectory_cost_cls: If given, installed as
                ``ArmReacher.trajectory_cost`` for the duration of the call.

        Returns:
            The result of ``MotionGenResult.get_interpolated_plan()``.

        Raises:
            ValueError: If planning fails. ``result.debug_info`` is written to
                ``debug_info.txt`` before raising.
        """
        timesteps = ee_poses.shape[0]
        world_config = world_config or {}

        patches = {
            "cost_fn": partialmethod(
                custom_cost_fn, ref_traj_mu=ee_poses, weights=weights
            ),
        }
        if trajectory_cost_cls is not None:
            patches["trajectory_cost"] = trajectory_cost_cls

        # Remember what ArmReacher defined itself, so the patch does not leak
        # into unrelated code that uses ArmReacher after this call.
        missing = object()
        saved = {name: vars(ArmReacher).get(name, missing) for name in patches}
        for name, value in patches.items():
            setattr(ArmReacher, name, value)
        try:
            motion_gen_config = MotionGenConfig.load_from_robot_config(
                self.robot_cfg,
                world_config,
                interpolation_dt=dt,
                trajopt_tsteps=timesteps,
                interpolation_steps=timesteps,
                finetune_trajopt_iters=200,
                grad_trajopt_iters=200,
                store_ik_debug=True,
                store_trajopt_debug=True,
            )
            motion_gen = MotionGen(motion_gen_config)
            motion_gen.warmup()

            # Respect the configured device/dtype instead of a hard-coded .cuda().
            start_qpos = torch.as_tensor(
                reference_qpos[:7],
                dtype=self.tensor_args.dtype,
                device=self.tensor_args.device,
            ).unsqueeze(0)
            goal_pose = Pose.from_list(ee_poses[-1].tolist())
            print("Goal Pose: ", goal_pose)
            start_state = JointState.from_position(
                start_qpos,
                joint_names=list(self._ARM_JOINT_NAMES),
            )

            plan_cfg = MotionGenPlanConfig(max_attempts=1)
            result = motion_gen.plan_single(start_state, goal_pose, plan_cfg)
            if not result.success:
                # debug_info is not necessarily a str; stringify it so the dump
                # cannot raise and hide the actual planning failure.
                with open("debug_info.txt", "w", encoding="utf-8") as f:
                    f.write(str(result.debug_info))
                raise ValueError(
                    f"Failed to generate trajectory (status: {result.status})"
                )
            return result.get_interpolated_plan()
        finally:
            for name, value in saved.items():
                if value is missing:
                    delattr(ArmReacher, name)
                else:
                    setattr(ArmReacher, name, value)
if __name__ == "__main__":
    # Load the demo inputs: target end-effector poses and the start joint config.
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # info_dict.pkl from a trusted source.
    with open("info_dict.pkl", "rb") as f:
        info = pickle.load(f)
    target_poses = info["ee_target_poses"]
    reference_qpos = info["initial_qpos"]

    cr_model = CuroboModel()
    # Put the reference poses on the model's configured device and dtype.
    target_tensor = torch.as_tensor(
        target_poses,
        dtype=cr_model.tensor_args.dtype,
        device=cr_model.tensor_args.device,
    )
    traj = cr_model.optimize_trajectory(
        ee_poses=target_tensor,
        reference_qpos=reference_qpos,
    )
robot_cfg:
  kinematics:
    use_usd_kinematics: False
    isaac_usd_path: "/Isaac/Robots/Franka/franka.usd"
    usd_path: "robot/franka_description/franka_panda_meters.usda"
    usd_robot_root: "/panda"
    # Per-joint axis for the USD flip settings. YAML mapping keys must be
    # unique, so usd_flip_joints is declared exactly once, in mapping form.
    usd_flip_joints: {
      "panda_joint1": "Z",
      "panda_joint2": "Z",
      "panda_joint3": "Z",
      "panda_joint4": "Z",
      "panda_joint5": "Z",
      "panda_joint6": "Z",
      "panda_joint7": "Z",
      "panda_finger_joint1": "Y",
      "panda_finger_joint2": "Y",
    }
    usd_flip_joint_limits: ["panda_finger_joint2"]
    urdf_path: "<INSERT_ROOT_PATH_OF_RLBench_HERE>/urdfs/panda/panda.urdf"
    asset_root_path: "<INSERT_ROOT_PATH_OF_RLBench_HERE>/urdfs/panda/"
    base_link: "robot_base"
    ee_link: "Pandatip"
    # link_names: null
    collision_link_names: [
      "robot_base",
      "Pandalink1respondable",
      "Pandalink2respondable",
      "Pandalink3respondable",
      "Pandalink4respondable",
      "Pandalink5respondable",
      "Pandalink6respondable",
      "Pandalink7respondable",
      "Pandagripper",
      "Pandaleftfingerrespondable",
      "Pandarightfingerrespondable",
      "attached_object"
    ]
    # NOTE: sphere definitions are based on cuRobo's franka.yml:
    # https://github.com/NVlabs/curobo/blob/18e9ebd35fcc7e5fe3bedf4635c6ac8b701c19b2/src/curobo/content/configs/robot/spheres/franka.yml
    # TODO: that file uses the upstream panda_link* names; its entries would need
    # renaming to the RLBench link names listed in collision_link_names.
    collision_spheres: "spheres/franka_mesh.yml"
    collision_sphere_buffer: 0.003 # 0.01
    extra_collision_spheres: {"attached_object": 4}
    use_global_cumul: True
    # NOTE(review): self_collision_ignore, self_collision_buffer, mesh_link_names
    # and extra_links below use upstream panda_* link names, not the RLBench
    # names in collision_link_names; confirm they resolve in this URDF.
    self_collision_ignore: {
      "panda_link0": ["panda_link1", "panda_link2"],
      "panda_link1": ["panda_link2", "panda_link3", "panda_link4"],
      "panda_link2": ["panda_link3", "panda_link4"],
      "panda_link3": ["panda_link4", "panda_link6"],
      "panda_link4": ["panda_link5", "panda_link6", "panda_link7", "panda_link8"],
      "panda_link5": ["panda_link6", "panda_link7", "panda_hand", "panda_leftfinger", "panda_rightfinger"],
      "panda_link6": ["panda_link7", "panda_hand", "attached_object", "panda_leftfinger", "panda_rightfinger"],
      "panda_link7": ["panda_hand", "attached_object", "panda_leftfinger", "panda_rightfinger"],
      "panda_hand": ["panda_leftfinger", "panda_rightfinger", "attached_object"],
      "panda_leftfinger": ["panda_rightfinger", "attached_object"],
      "panda_rightfinger": ["attached_object"],
    }
    self_collision_buffer: {
      "panda_link0": 0.1,
      "panda_link1": 0.05,
      "panda_link2": 0.0,
      "panda_link3": 0.0,
      "panda_link4": 0.0,
      "panda_link5": 0.0,
      "panda_link6": 0.0,
      "panda_link7": 0.0,
      "panda_hand": 0.0,
      "panda_leftfinger": 0.01,
      "panda_rightfinger": 0.01,
      "attached_object": 0.0,
    }
    mesh_link_names: [
      "panda_link0",
      "panda_link1",
      "panda_link2",
      "panda_link3",
      "panda_link4",
      "panda_link5",
      "panda_link6",
      "panda_link7",
      "panda_hand",
      "panda_leftfinger",
      "panda_rightfinger",
    ]
    lock_joints: {"panda_finger_joint1": 0.04, "panda_finger_joint2": -0.04}
    extra_links: {
      "attached_object": {
        "parent_link_name": "panda_hand",
        "link_name": "attached_object",
        "fixed_transform": [0, 0, 0, 1, 0, 0, 0],
        "joint_type": "FIXED",
        "joint_name": "attach_joint"
      }
    }
    cspace:
      joint_names: [
        "panda_joint1",
        "panda_joint2",
        "panda_joint3",
        "panda_joint4",
        "panda_joint5",
        "panda_joint6",
        "panda_joint7",
        "panda_finger_joint1",
        "panda_finger_joint2"
      ]
      retract_config: [0.0, -1.3, 0.0, -2.5, 0.0, 1.0, 0., 0.04, -0.04]
      null_space_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1]
      cspace_distance_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1]
      max_acceleration: 15.0
      max_jerk: 500.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment