Created December 9, 2024 14:00
-
-
Save vonHartz/a5a910694ca1e4bd5fa3856f785753be to your computer and use it in GitHub Desktop.
Minimal example of a custom cuRobo trajectory cost
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partialmethod | |
from typing import Optional, Tuple | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.autograd.profiler as profiler | |
from curobo.cuda_robot_model.cuda_robot_model import ( | |
CudaRobotModel, | |
CudaRobotModelConfig, | |
) | |
from curobo.curobolib.geom import geom_cu | |
from curobo.rollout.arm_reacher import ( | |
ArmReacher, | |
_compute_g_dist_jit, | |
cat_sum_horizon_reacher, | |
cat_sum_reacher, | |
) | |
from curobo.rollout.dynamics_model.kinematic_model import KinematicModelState | |
from curobo.types.base import TensorDeviceType | |
from curobo.types.math import Pose | |
from curobo.types.robot import JointState, RobotConfig | |
from curobo.util_file import get_robot_path, join_path, load_yaml | |
from curobo.wrap.reacher.motion_gen import ( | |
MotionGen, | |
MotionGenConfig, | |
MotionGenPlanConfig, | |
) | |
import pickle | |
# Matplotlib named colours, presumably one per Cartesian axis (3 entries) and
# one per quaternion component (4 entries) for plotting; the quaternion
# palette extends the axis palette with one extra colour.
dim_colors = ("tab:red", "tab:green", "tab:blue")
quat_colors = (*dim_colors, "tab:brown")
def custom_cost_fn(
    self: ArmReacher,
    state: KinematicModelState,
    action_batch=None,
    ref_traj_mu: torch.Tensor | None = None,
    weights: torch.Tensor | None = None,
    ref_traj_weight: float = 1000.0,
):
    """Compute the reacher cost plus a reference-trajectory tracking term.

    Replacement for ``ArmReacher.cost_fn``. It collects the following costs,
    each only when configured and enabled:

    - the base rollout costs;
    - the goal-pose cost;
    - the per-link pose costs;
    - the c-space goal cost;
    - the straight-line cost;
    - the zero-acceleration, zero-jerk and zero-velocity costs.

    It then appends ``ref_traj_weight`` times the Euclidean distance between
    the end-effector position and the reference positions
    ``ref_traj_mu[..., :3]``. Everything is reduced with
    ``cat_sum_horizon_reacher`` / ``cat_sum_reacher``.

    See :class:`curobo.rollout.cost.PoseCost` and
    :class:`curobo.rollout.cost.DistCost`.

    Args:
        self: ``ArmReacher`` rollout whose configured cost terms are evaluated.
            The function is meant to be bound onto ``ArmReacher``, e.g. via
            ``functools.partialmethod``.
        state: Rolled-out kinematic state. ``state_seq``, ``ee_pos_seq``,
            ``ee_quat_seq`` and ``link_pose`` are read from it.
        action_batch: Forwarded to the parent class's ``cost_fn``.
        ref_traj_mu: Reference trajectory. Only the first three entries of its
            last axis are used, presumably xyz positions (TODO confirm). They
            must broadcast against ``state.ee_pos_seq``; for a 2-D reference
            that means one row per horizon step. If ``None``, the tracking
            term is omitted.
        weights: Currently unused. It is accepted so that existing
            ``partialmethod`` bindings that pass it keep working.
        ref_traj_weight: Scale of the tracking term. Defaults to 1000.0, the
            previously hard-coded value.

    Returns:
        The summed cost tensor: from ``cat_sum_horizon_reacher`` when
        ``self.sum_horizon`` is truthy, otherwise from ``cat_sum_reacher``.
    """
    state_batch = state.state_seq
    with profiler.record_function("cost/base"):
        cost_list = super(ArmReacher, self).cost_fn(
            state, action_batch, return_list=True
        )
    ee_pos_batch, ee_quat_batch = state.ee_pos_seq, state.ee_quat_seq
    # Combined goal distance. It stays None unless the goal cost computes it;
    # the zero-acc/jerk/vel terms below receive it either way.
    g_dist = None

    with profiler.record_function("cost/pose"):
        if (
            self._goal_buffer.goal_pose.position is not None
            and self.cost_cfg.pose_cfg is not None
            and self.goal_cost.enabled
        ):
            if self._compute_g_dist:
                goal_cost, rot_err_norm, goal_dist = (
                    self.goal_cost.forward_out_distance(
                        ee_pos_batch,
                        ee_quat_batch,
                        self._goal_buffer,
                    )
                )
                g_dist = _compute_g_dist_jit(rot_err_norm, goal_dist)
            else:
                goal_cost = self.goal_cost.forward(
                    ee_pos_batch, ee_quat_batch, self._goal_buffer
                )
            cost_list.append(goal_cost)

    with profiler.record_function("cost/link_poses"):
        if (
            self._goal_buffer.links_goal_pose is not None
            and self.cost_cfg.pose_cfg is not None
        ):
            link_poses = state.link_pose
            for k in self._goal_buffer.links_goal_pose.keys():
                # Skip the end-effector link; its pose is handled by the
                # goal-pose term above.
                if k == self.kinematics.ee_link:
                    continue
                current_fn = self._link_pose_costs[k]
                if not current_fn.enabled:
                    continue
                current_pose = link_poses[k].contiguous()
                c = current_fn.forward(
                    current_pose.position,
                    current_pose.quaternion,
                    self._goal_buffer,
                    k,
                )
                cost_list.append(c)

    if (
        self._goal_buffer.goal_state is not None
        and self.cost_cfg.cspace_cfg is not None
        and self.dist_cost.enabled
    ):
        joint_cost = self.dist_cost.forward_target_idx(
            self._goal_buffer.goal_state.position,
            state_batch.position,
            self._goal_buffer.batch_goal_state_idx,
        )
        cost_list.append(joint_cost)

    if self.cost_cfg.straight_line_cfg is not None and self.straight_line_cost.enabled:
        st_cost = self.straight_line_cost.forward(ee_pos_batch)
        cost_list.append(st_cost)

    if self.cost_cfg.zero_acc_cfg is not None and self.zero_acc_cost.enabled:
        z_acc = self.zero_acc_cost.forward(state_batch.acceleration, g_dist)
        cost_list.append(z_acc)

    if self.cost_cfg.zero_jerk_cfg is not None and self.zero_jerk_cost.enabled:
        z_jerk = self.zero_jerk_cost.forward(state_batch.jerk, g_dist)
        cost_list.append(z_jerk)

    if self.cost_cfg.zero_vel_cfg is not None and self.zero_vel_cost.enabled:
        z_vel = self.zero_vel_cost.forward(state_batch.velocity, g_dist)
        cost_list.append(z_vel)

    # Reference-trajectory tracking term. It is applied whether or not the
    # goal cost ran, matching the previous unconditional behaviour.
    if ref_traj_mu is not None:
        # Align with the rollout tensors; .to() is a no-op when device and
        # dtype already match.
        ref_pos = ref_traj_mu[..., :3].to(
            device=ee_pos_batch.device, dtype=ee_pos_batch.dtype
        )
        tracking_cost = ref_traj_weight * torch.linalg.norm(
            ref_pos - ee_pos_batch, dim=-1
        )
        cost_list.append(tracking_cost)

    with profiler.record_function("cat_sum"):
        if self.sum_horizon:
            cost = cat_sum_horizon_reacher(cost_list)
        else:
            cost = cat_sum_reacher(cost_list)
    return cost
class CuroboModel:
    """Wrapper around a cuRobo robot model used to optimise EE trajectories.

    Loads a cuRobo robot YAML and builds a ``RobotConfig`` from its URDF path,
    base link and end-effector link. It exposes :meth:`optimize_trajectory`,
    which plans with ``MotionGen`` after patching :func:`custom_cost_fn` into
    ``ArmReacher``.
    """

    # Placeholder path to the robot YAML; override via ``config_path``.
    DEFAULT_CONFIG_PATH = (
        "<INSERT_PARENT_DIR_OF_rlbench_panda.yml_HERE>/rlbench_panda.yml"
    )
    # Arm joints of the start state, in the order of ``reference_qpos``.
    ARM_JOINT_NAMES = (
        "panda_joint1",
        "panda_joint2",
        "panda_joint3",
        "panda_joint4",
        "panda_joint5",
        "panda_joint6",
        "panda_joint7",
    )

    def __init__(self, config_path: str = DEFAULT_CONFIG_PATH) -> None:
        """Load the robot YAML at ``config_path`` and build the kinematic model.

        Only ``robot_cfg.kinematics.{urdf_path, base_link, ee_link}`` are read
        from the YAML to build ``robot_cfg``. The full parsed dict is kept in
        ``self.config_file``.
        """
        self.tensor_args = TensorDeviceType()
        self.config_file = load_yaml(config_path)
        kinematics_cfg = self.config_file["robot_cfg"]["kinematics"]
        self.urdf_file = kinematics_cfg["urdf_path"]
        self.base_link = kinematics_cfg["base_link"]
        self.ee_link = kinematics_cfg["ee_link"]
        self.robot_cfg = RobotConfig.from_basic(
            self.urdf_file, self.base_link, self.ee_link, self.tensor_args
        )
        self.model = CudaRobotModel(self.robot_cfg.kinematics)

    def optimize_trajectory(
        self,
        ee_poses: torch.Tensor,
        reference_qpos: torch.Tensor,
        weights: torch.Tensor | None = None,
        world_config: dict | None = None,
        dt: float = 0.05,
        trajectory_cost=None,
    ):
        """Plan a trajectory that tracks ``ee_poses`` and ends at its last pose.

        Side effect: ``ArmReacher.cost_fn`` is replaced at class level, so the
        patch affects every ``ArmReacher`` in the process, not only the
        ``MotionGen`` built here.

        Args:
            ee_poses: Target end-effector poses, one row per timestep.
                - The row count sets ``trajopt_tsteps`` and
                  ``interpolation_steps``.
                - Columns ``[:3]`` feed the tracking cost.
                - The last row is the goal, passed to ``Pose.from_list``.
                - NOTE(review): ``Pose.from_list`` is believed to expect
                  ``[x, y, z, qw, qx, qy, qz]`` by default. Confirm this
                  matches the source of these poses.
            reference_qpos: Start joint configuration. Its first
                ``len(ARM_JOINT_NAMES)`` entries are used. It may be a list,
                a numpy array or a tensor on any device.
            weights: Forwarded to ``custom_cost_fn`` as ``weights``.
            world_config: cuRobo world description. ``None`` means an empty
                world.
            dt: Interpolation time step passed as ``interpolation_dt``.
            trajectory_cost: Optional replacement for
                ``ArmReacher.trajectory_cost``. The previous version
                unconditionally assigned ``UncertaintyAwareTrajectoryCost``,
                a name not defined in this file, which raised NameError.

        Returns:
            ``result.get_interpolated_plan()`` of the successful plan.

        Raises:
            ValueError: If planning fails. Debug info is written to
                ``debug_info.txt`` in the current working directory first.
        """
        timesteps = ee_poses.shape[0]
        world_config = world_config or {}

        # Bind the reference trajectory so MotionGen's internal calls to
        # cost_fn(state, action_batch) include the tracking term.
        ArmReacher.cost_fn = partialmethod(
            custom_cost_fn, ref_traj_mu=ee_poses, weights=weights
        )
        if trajectory_cost is not None:
            ArmReacher.trajectory_cost = trajectory_cost

        # NOTE(review): interpolation_steps is tied to the number of targets;
        # confirm it is large enough for the interpolated plan at ``dt``.
        motion_gen_config = MotionGenConfig.load_from_robot_config(
            self.robot_cfg,
            world_config,
            interpolation_dt=dt,
            trajopt_tsteps=timesteps,
            interpolation_steps=timesteps,
            finetune_trajopt_iters=200,
            grad_trajopt_iters=200,
            store_ik_debug=True,
            store_trajopt_debug=True,
        )
        motion_gen = MotionGen(motion_gen_config)
        motion_gen.warmup()

        # Use the model's device/dtype. torch.as_tensor also accepts inputs
        # that are already CUDA tensors, unlike the legacy torch.Tensor().
        start_qpos = torch.as_tensor(
            reference_qpos[: len(self.ARM_JOINT_NAMES)],
            device=self.tensor_args.device,
            dtype=self.tensor_args.dtype,
        ).unsqueeze(0)
        goal_pose = Pose.from_list(ee_poses[-1])
        print("Goal Pose: ", goal_pose)
        start_state = JointState.from_position(
            start_qpos, joint_names=list(self.ARM_JOINT_NAMES)
        )

        plan_cfg = MotionGenPlanConfig(max_attempts=1)
        result = motion_gen.plan_single(start_state, goal_pose, plan_cfg)
        if not result.success:
            # debug_info is not guaranteed to be a str (TODO confirm its type
            # in cuRobo). Stringify it so f.write cannot raise and mask the
            # planning failure.
            with open("debug_info.txt", "w", encoding="utf-8") as f:
                f.write(str(result.debug_info))
            status = getattr(result, "status", None)
            raise ValueError(f"Failed to generate trajectory (status: {status})")
        return result.get_interpolated_plan()
def main(info_path: str = "info_dict.pkl"):
    """Load targets from ``info_path`` and optimise a trajectory through them.

    The pickle must contain ``"ee_target_poses"`` (target end-effector poses)
    and ``"initial_qpos"`` (start joint configuration).

    Returns:
        The interpolated plan from ``CuroboModel.optimize_trajectory``.
    """
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(info_path, "rb") as f:
        info = pickle.load(f)
    target_poses = info["ee_target_poses"]
    reference_qpos = info["initial_qpos"]

    cr_model = CuroboModel()
    target_tensor = torch.as_tensor(
        target_poses,
        dtype=cr_model.tensor_args.dtype,
        device=cr_model.tensor_args.device,
    )
    return cr_model.optimize_trajectory(
        ee_poses=target_tensor,
        reference_qpos=reference_qpos,
    )


# Guarded so importing this module does not trigger GPU planning.
if __name__ == "__main__":
    traj = main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cuRobo robot configuration for the Franka Panda URDF shipped with RLBench.
robot_cfg:
  kinematics:
    use_usd_kinematics: False
    isaac_usd_path: "/Isaac/Robots/Franka/franka.usd"
    usd_path: "robot/franka_description/franka_panda_meters.usda"
    usd_robot_root: "/panda"
    # This key used to be declared twice: first as a plain list of joint
    # names, then as the mapping below. Duplicate mapping keys are invalid
    # YAML. PyYAML keeps the last occurrence, so the mapping is the value
    # that was actually loaded, and it is the only one kept.
    usd_flip_joints: {
      "panda_joint1": "Z",
      "panda_joint2": "Z",
      "panda_joint3": "Z",
      "panda_joint4": "Z",
      "panda_joint5": "Z",
      "panda_joint6": "Z",
      "panda_joint7": "Z",
      "panda_finger_joint1": "Y",
      "panda_finger_joint2": "Y",
    }
    usd_flip_joint_limits: ["panda_finger_joint2"]
    urdf_path: "<INSERT_ROOT_PATH_OF_RLBench_HERE>/urdfs/panda/panda.urdf"
    asset_root_path: "<INSERT_ROOT_PATH_OF_RLBench_HERE>/urdfs/panda/"
    base_link: "robot_base"
    ee_link: "Pandatip"
    # link_names: null
    collision_link_names: [
      "robot_base", "Pandalink1respondable", "Pandalink2respondable",
      "Pandalink3respondable", "Pandalink4respondable", "Pandalink5respondable",
      "Pandalink6respondable", "Pandalink7respondable", "Pandagripper",
      "Pandaleftfingerrespondable", "Pandarightfingerrespondable",
      "attached_object"]
    # NOTE: spheres are from
    # https://github.com/NVlabs/curobo/blob/18e9ebd35fcc7e5fe3bedf4635c6ac8b701c19b2/src/curobo/content/configs/robot/spheres/franka.yml
    # Those sphere definitions likely need renaming to match the RLBench link
    # names used in collision_link_names above.
    collision_spheres: "spheres/franka_mesh.yml"
    collision_sphere_buffer: 0.003 # 0.01
    extra_collision_spheres: {"attached_object": 4}
    use_global_cumul: True
    # NOTE(review): these entries still use the stock Franka link names:
    #   - self_collision_ignore and self_collision_buffer;
    #   - mesh_link_names;
    #   - extra_links.parent_link_name.
    # The names are "panda_link*", "panda_hand" and "panda_*finger", whereas
    # collision_link_names uses the RLBench names (e.g.
    # "Pandalink1respondable"). If the RLBench URDF has no links with the
    # stock names, these entries will not match anything. Confirm against
    # the URDF.
    self_collision_ignore: {
      "panda_link0": ["panda_link1", "panda_link2"],
      "panda_link1": ["panda_link2", "panda_link3", "panda_link4"],
      "panda_link2": ["panda_link3", "panda_link4"],
      "panda_link3": ["panda_link4", "panda_link6"],
      "panda_link4":
        ["panda_link5", "panda_link6", "panda_link7", "panda_link8"],
      "panda_link5": ["panda_link6", "panda_link7", "panda_hand", "panda_leftfinger", "panda_rightfinger"],
      "panda_link6": ["panda_link7", "panda_hand", "attached_object", "panda_leftfinger", "panda_rightfinger"],
      "panda_link7": ["panda_hand", "attached_object", "panda_leftfinger", "panda_rightfinger"],
      "panda_hand": ["panda_leftfinger", "panda_rightfinger", "attached_object"],
      "panda_leftfinger": ["panda_rightfinger", "attached_object"],
      "panda_rightfinger": ["attached_object"],
    }
    self_collision_buffer: {
      "panda_link0": 0.1,
      "panda_link1": 0.05,
      "panda_link2": 0.0,
      "panda_link3": 0.0,
      "panda_link4": 0.0,
      "panda_link5": 0.0,
      "panda_link6": 0.0,
      "panda_link7": 0.0,
      "panda_hand": 0.0,
      "panda_leftfinger": 0.01,
      "panda_rightfinger": 0.01,
      "attached_object": 0.0,
    }
    mesh_link_names: [
      "panda_link0",
      "panda_link1",
      "panda_link2",
      "panda_link3",
      "panda_link4",
      "panda_link5",
      "panda_link6",
      "panda_link7",
      "panda_hand",
      "panda_leftfinger",
      "panda_rightfinger",
    ]
    lock_joints: {"panda_finger_joint1": 0.04, "panda_finger_joint2": -0.04}
    extra_links: {
      "attached_object": {
        "parent_link_name": "panda_hand",
        "link_name": "attached_object",
        "fixed_transform": [0, 0, 0, 1, 0, 0, 0],
        "joint_type": "FIXED",
        "joint_name": "attach_joint"
      }
    }
    cspace:
      joint_names: [
        "panda_joint1",
        "panda_joint2",
        "panda_joint3",
        "panda_joint4",
        "panda_joint5",
        "panda_joint6",
        "panda_joint7",
        "panda_finger_joint1",
        "panda_finger_joint2"
      ]
      retract_config: [0.0, -1.3, 0.0, -2.5, 0.0, 1.0, 0., 0.04, -0.04]
      null_space_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1]
      cspace_distance_weight: [1, 1, 1, 1, 1, 1, 1, 1, 1]
      max_acceleration: 15.0
      max_jerk: 500.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.