Created
March 4, 2025 19:33
-
-
Save shreyasgite/3de71719c1f03439ed7278b9ba85b14b to your computer and use it in GitHub Desktop.
Helper functions for augmenting robot trajectories recorded with the SO-100 robot arm and the LeRobot dataset format.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def flip_frame(frame_data, flip=True):
    """Return a copy of *frame_data*, optionally mirrored left/right.

    Camera images are flipped horizontally and the first component of the
    ``action`` and ``observation.state`` vectors is negated (presumably the
    lateral axis of the so100 robot — TODO confirm against the joint layout).
    Independently of *flip*, the ``timestamp`` entry is normalized to a 1-D
    numpy array and an ``augmentation`` tag is recorded on the frame.

    Args:
        frame_data: Dict containing one frame of episode data.
        flip: Whether to actually mirror the frame (default: True).
            If False, only the format standardization (timestamp shape,
            augmentation tag) is performed.

    Returns:
        A new dict with the (possibly flipped) frame data; the input dict
        is never mutated.
    """
    import copy
    import numpy as np

    def _negate_first(vec):
        # Negate the first component of a vector, accepting either a
        # PyTorch tensor (detected via .cpu) or a numpy array.  Unknown
        # types pass through untouched.
        if hasattr(vec, 'cpu'):  # PyTorch tensor -> numpy
            vec = vec.cpu().numpy()
            if len(vec.shape) >= 1 and vec.shape[0] > 0:
                vec[0] = -vec[0]
            return vec
        if isinstance(vec, np.ndarray):
            if len(vec.shape) >= 1 and vec.shape[0] > 0:
                out = vec.copy()
                out[0] = -out[0]
                return out
        return vec

    def _flip_image(img):
        # Mirror an image along its width axis.  Torch tensors are assumed
        # channel-first [C, H, W]; for numpy arrays the layout is guessed
        # from the shape (the channel dimension is the small one) —
        # heuristic, breaks on tiny square images; TODO confirm upstream.
        if hasattr(img, 'cpu'):  # PyTorch tensor -> numpy
            arr = img.cpu().numpy()
            return np.flip(arr, axis=2 if len(arr.shape) == 3 else 1)
        if isinstance(img, np.ndarray):
            if len(img.shape) == 3:
                # [H, W, C] -> flip axis 1 (W); [C, H, W] -> flip axis 2 (W).
                return np.flip(img, axis=1 if img.shape[0] > img.shape[2] else 2)
            return np.flip(img, axis=1)
        # Neither a tensor nor a numpy array: we don't know how to flip it.
        return img

    new_frame = copy.deepcopy(frame_data)

    if flip:
        for key in new_frame:
            # 'observation.images' as a substring already covers keys that
            # start with 'observation.images.'.
            if 'observation.images' in key:
                new_frame[key] = _flip_image(new_frame[key])
            elif key == 'action' or 'observation.state' in key:
                new_frame[key] = _negate_first(new_frame[key])

    # Ensure timestamp is a flat 1-D numpy array regardless of input type.
    if 'timestamp' in new_frame:
        timestamp = new_frame['timestamp']
        if not isinstance(timestamp, np.ndarray):
            if hasattr(timestamp, 'cpu'):
                timestamp = timestamp.cpu().numpy()
            else:
                timestamp = np.array([timestamp])
        new_frame['timestamp'] = timestamp.flatten()

    # Tag the frame so downstream code can tell originals from augmentations.
    if flip:
        aug_type = "flipped"
    else:
        aug_type = "original"
    new_frame['augmentation'] = aug_type
    return new_frame
def create_spliced_episodes(dataset, ep1_idx, flipped_frames, noise_scale=0.05, fps=30):
    """Create two new episodes by splicing parts of two episodes at matching states.

    The two source trajectories (an original episode and its pre-flipped
    twin) are cut at the pair of frames — one from each episode, restricted
    to the middle ~5 seconds — whose ``observation.state`` vectors are
    closest in Euclidean distance, then cross-combined:
    episode 3 = head of ep1 + tail of ep2, episode 4 = head of ep2 + tail
    of ep1.  Gaussian noise is added to actions and states, and tail
    timestamps are shifted so each spliced episode has a continuous clock.

    Args:
        dataset: LeRobotDataset containing the original episode.
        ep1_idx: Index of the first episode (original).
        flipped_frames: List of tuples (episode_idx, frame_data) for flipped frames.
        noise_scale: Std-dev of random Gaussian noise added to action and
            state values (default: 0.05).
        fps: Frames per second of the dataset (default: 30); sizes the
            3-second safe buffers and the 5-second middle window.

    Returns:
        Tuple of two lists: (episode3_frames, episode4_frames), each
        containing (episode_idx, frame_data) tuples for the new episodes.

    Raises:
        ValueError: If no state vectors are found in one or both middle sections.
    """
    import copy
    import numpy as np
    from scipy.spatial.distance import cdist

    def _as_numpy(vec):
        # PyTorch tensor -> numpy; numpy arrays pass through unchanged.
        return vec if isinstance(vec, np.ndarray) else vec.cpu().numpy()

    def _add_noise(frame):
        # Jitter action/state with zero-mean Gaussian noise so the spliced
        # trajectory is not a byte-for-byte copy of its sources.
        for key in ('action', 'observation.state'):
            if key in frame:
                vec = _as_numpy(frame[key])
                frame[key] = vec + np.random.normal(0, noise_scale, vec.shape)
        return frame

    def _scalar_time(t):
        # Collapse a 1-element array / tensor timestamp to a python scalar.
        return t.item() if hasattr(t, 'item') else t

    def _mid_window(n_frames):
        # Middle ~5 s window, clamped to keep 3 s clear at both ends.
        buffer = 3 * fps
        start = max(buffer, n_frames // 2 - 2.5 * fps)
        end = min(n_frames - buffer, n_frames // 2 + 2.5 * fps)
        return int(start), int(end)

    def _collect_states(frames, start, end):
        # (frame_index, state-vector) pairs for frames carrying a state.
        return [(i, _as_numpy(frames[i]['observation.state']))
                for i in range(start, end)
                if 'observation.state' in frames[i]]

    def _splice(head_frames, head_end, tail_frames, tail_start, out_idx):
        # head_frames[:head_end+1] followed by tail_frames[tail_start:],
        # with noisy action/state; tail timestamps are shifted so they
        # continue from the head's clock.
        spliced = []
        for i in range(head_end + 1):
            frame = _add_noise(copy.deepcopy(head_frames[i]))
            # NOTE(review): flip_frame is called with its default flip=True,
            # so every spliced frame is mirrored once more; the original
            # comment said this was "to ensure consistent format", which
            # would be flip=False — confirm intent before changing.
            spliced.append((out_idx, flip_frame(frame)))
        if ('timestamp' in head_frames[head_end]
                and 'timestamp' in tail_frames[tail_start]):
            time_offset = (_scalar_time(head_frames[head_end]['timestamp'])
                           - _scalar_time(tail_frames[tail_start]['timestamp']))
        else:
            time_offset = 0.0
        for i in range(tail_start, len(tail_frames)):
            frame = _add_noise(copy.deepcopy(tail_frames[i]))
            if 'timestamp' in frame:
                frame['timestamp'] = frame['timestamp'] + time_offset
            spliced.append((out_idx, flip_frame(frame)))
        return spliced

    print(f"Creating spliced episodes from original episode {ep1_idx} and its flipped version")

    # Frames of the original episode, pulled straight from the dataset.
    ep1_start_idx = dataset.episode_data_index["from"][ep1_idx].item()
    ep1_end_idx = dataset.episode_data_index["to"][ep1_idx].item()
    episode1_frames = [dataset[i] for i in range(ep1_start_idx, ep1_end_idx)]

    # Flipped frames arrive as (episode_idx, frame_data) tuples; keep the
    # frames and remember the flipped episode's index for numbering.
    episode2_frames = [frame_data for _, frame_data in flipped_frames]
    flipped_episode_idx = flipped_frames[0][0] if flipped_frames else ep1_idx + 100

    print(f"Episode 1 (original): {len(episode1_frames)} frames")
    print(f"Episode 2 (flipped): {len(episode2_frames)} frames")

    ep1_mid_start, ep1_mid_end = _mid_window(len(episode1_frames))
    ep2_mid_start, ep2_mid_end = _mid_window(len(episode2_frames))
    print(f"Episode 1 middle section: frames {ep1_mid_start} to {ep1_mid_end}")
    print(f"Episode 2 middle section: frames {ep2_mid_start} to {ep2_mid_end}")

    ep1_mid_states = _collect_states(episode1_frames, ep1_mid_start, ep1_mid_end)
    ep2_mid_states = _collect_states(episode2_frames, ep2_mid_start, ep2_mid_end)
    if not ep1_mid_states or not ep2_mid_states:
        raise ValueError("Couldn't find state vectors in one or both episodes")

    # Pairwise Euclidean distances between every candidate state pair;
    # the global minimum picks the splice point in each episode.
    ep1_states = np.array([state for _, state in ep1_mid_states])
    ep2_states = np.array([state for _, state in ep2_mid_states])
    distances = cdist(ep1_states, ep2_states, 'euclidean')
    min_row, min_col = np.unravel_index(np.argmin(distances), distances.shape)
    best_distance = distances[min_row, min_col]
    ep1_match_idx = ep1_mid_states[min_row][0]
    ep2_match_idx = ep2_mid_states[min_col][0]
    print(f"Best matching states found at:")
    print(f" Episode 1 frame {ep1_match_idx}")
    print(f" Episode 2 frame {ep2_match_idx}")
    print(f" Distance: {best_distance}")

    # High output indices (+100 / +101) avoid clashing with real episodes.
    episode3_frames = _splice(episode1_frames, ep1_match_idx,
                              episode2_frames, ep2_match_idx,
                              flipped_episode_idx + 100)
    episode4_frames = _splice(episode2_frames, ep2_match_idx,
                              episode1_frames, ep1_match_idx,
                              flipped_episode_idx + 101)

    print(f"Created Episode 3 with {len(episode3_frames)} frames")
    print(f"Created Episode 4 with {len(episode4_frames)} frames")
    return episode3_frames, episode4_frames
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment