Skip to content

Instantly share code, notes, and snippets.

@shreyasgite
Created March 4, 2025 19:33
Show Gist options
  • Save shreyasgite/3de71719c1f03439ed7278b9ba85b14b to your computer and use it in GitHub Desktop.
Save shreyasgite/3de71719c1f03439ed7278b9ba85b14b to your computer and use it in GitHub Desktop.
Helper functions for augmenting robot trajectories for the SO-100 robot and LeRobot.
def flip_frame(frame_data, flip=True):
    """
    Create a flipped version of a frame by horizontally mirroring camera
    images and negating the first component of action and state vectors.

    Args:
        frame_data: Dict containing frame data (camera images under keys
            containing 'observation.images', plus optional 'action',
            'observation.state', and 'timestamp' entries).
        flip: If True (default), apply the horizontal flip. If False, only
            format standardization is performed (timestamp shaping and the
            'augmentation' metadata key).

    Returns:
        Dict containing the (possibly) flipped frame data. 'timestamp' is
        normalized to a flat 1-D numpy array and an 'augmentation' key is
        set to "flipped" or "original".
    """
    import copy
    import numpy as np

    def _to_numpy(value):
        # PyTorch tensors expose .cpu(); everything else passes through.
        return value.cpu().numpy() if hasattr(value, 'cpu') else value

    def _negate_first(vec):
        # Negate the first component (typically the x-axis). Works on
        # numpy arrays and torch tensors; unknown types are left untouched.
        arr = _to_numpy(vec)
        if isinstance(arr, np.ndarray) and arr.ndim >= 1 and arr.shape[0] > 0:
            arr = arr.copy()  # never mutate shared storage in place
            arr[0] = -arr[0]
        return arr

    new_frame = copy.deepcopy(frame_data)

    for key in new_frame:
        if flip and 'observation.images' in key:
            value = new_frame[key]
            img = _to_numpy(value)
            if hasattr(value, 'cpu'):
                # Torch image tensors are conventionally [C, H, W]:
                # flip the width axis.
                axis = 2 if img.ndim == 3 else 1
            elif isinstance(img, np.ndarray):
                if img.ndim == 3:
                    # Heuristic: the channel axis is the smallest one, so
                    # [C, H, W] when shape[0] > shape[2] is False, else [H, W, C].
                    axis = 1 if img.shape[0] > img.shape[2] else 2
                else:
                    axis = 1
            else:
                # Neither tensor nor ndarray: we don't know how to flip it.
                continue
            new_frame[key] = np.flip(img, axis=axis)
        elif flip and (key == 'action' or 'observation.state' in key):
            # Mirrored world => x-axis of the action/state vector is inverted.
            new_frame[key] = _negate_first(new_frame[key])

    # Ensure timestamp is a flat 1-D numpy array regardless of input type.
    if 'timestamp' in new_frame:
        timestamp = new_frame['timestamp']
        if not isinstance(timestamp, np.ndarray):
            if hasattr(timestamp, 'cpu'):
                timestamp = timestamp.cpu().numpy()
            else:
                timestamp = np.array([timestamp])
        new_frame['timestamp'] = timestamp.flatten()

    # Tag the frame so downstream consumers can tell augmented data apart.
    new_frame['augmentation'] = "flipped" if flip else "original"
    return new_frame
def create_spliced_episodes(dataset, ep1_idx, flipped_frames, noise_scale=0.05, fps=30):
    """
    Create two new episodes by splicing parts of two episodes at matching states.

    The middle sections (center +/- 2.5 s, clamped away from the first and
    last 3 s) of the original and flipped episodes are scanned for the pair
    of frames whose state vectors are closest in Euclidean distance; each new
    episode takes the head of one episode up to its matching frame and the
    tail of the other after its matching frame, with timestamps shifted so
    the transition is continuous.

    Args:
        dataset: LeRobotDataset containing the original episode. Must expose
            `episode_data_index["from"/"to"]` and integer indexing.
        ep1_idx: Index of the first episode (original).
        flipped_frames: List of tuples (episode_idx, frame_data) for flipped frames.
        noise_scale: Std-dev of Gaussian noise added to action/state values
            (default: 0.05).
        fps: Frames per second of the dataset (default: 30).

    Returns:
        Tuple of two lists (episode3_frames, episode4_frames), each containing
        (episode_idx, frame_data) tuples for the new spliced episodes.

    Raises:
        ValueError: If no 'observation.state' vectors are found in the middle
            section of either episode.
    """
    import numpy as np
    from scipy.spatial.distance import cdist
    import copy

    def _as_numpy(value):
        # Convert torch tensors to numpy; pass numpy/scalars through.
        return value.cpu().numpy() if hasattr(value, 'cpu') else value

    def _scalar(ts):
        # Timestamp may be a python float, a 0/1-d numpy array, or a tensor.
        ts = _as_numpy(ts)
        return ts.item() if hasattr(ts, 'item') else float(ts)

    def _noisy(frame):
        # Deep-copy a frame and jitter its action/state with Gaussian noise.
        frame = copy.deepcopy(frame)
        for key in ('action', 'observation.state'):
            if key in frame:
                vec = _as_numpy(frame[key])
                if isinstance(vec, np.ndarray):
                    frame[key] = vec + np.random.normal(0, noise_scale, vec.shape)
        return frame

    def _mid_states(frames, lo, hi):
        # Collect (frame_index, state_vector) pairs from the middle section.
        return [(i, np.asarray(_as_numpy(frames[i]['observation.state'])))
                for i in range(lo, hi)
                if 'observation.state' in frames[i]]

    def _mid_bounds(n):
        # Middle 5 s of an n-frame episode, clamped to avoid the first and
        # last 3 s. Truncated to ints (2.5 * fps may be fractional).
        buffer = 3 * fps
        half_window = 2.5 * fps
        return (int(max(buffer, n // 2 - half_window)),
                int(min(n - buffer, n // 2 + half_window)))

    def _splice(head_frames, head_end, tail_frames, tail_start, out_idx):
        # Head frames [0, head_end] followed by tail frames after the match.
        result = []
        for i in range(head_end + 1):
            # flip=False: flip_frame is used only for format standardization
            # here (the flipped episode was already flipped upstream).
            result.append((out_idx, flip_frame(_noisy(head_frames[i]), flip=False)))
        # Shift tail timestamps so the sequence stays continuous across the
        # splice point.
        head_last = head_frames[head_end]
        tail_first = tail_frames[tail_start]
        if 'timestamp' in head_last and 'timestamp' in tail_first:
            time_offset = _scalar(head_last['timestamp']) - _scalar(tail_first['timestamp'])
        else:
            time_offset = 0.0
        # Start at tail_start + 1: the matching frame is already represented
        # by the head, and including it again would produce two frames with
        # identical timestamps.
        for i in range(tail_start + 1, len(tail_frames)):
            frame = _noisy(tail_frames[i])
            if 'timestamp' in frame:
                frame['timestamp'] = frame['timestamp'] + time_offset
            result.append((out_idx, flip_frame(frame, flip=False)))
        return result

    print(f"Creating spliced episodes from original episode {ep1_idx} and its flipped version")

    # Extract frames from the first (original) episode.
    ep1_start_idx = dataset.episode_data_index["from"][ep1_idx].item()
    ep1_end_idx = dataset.episode_data_index["to"][ep1_idx].item()
    episode1_frames = [dataset[i] for i in range(ep1_start_idx, ep1_end_idx)]

    # The flipped frames arrive as (episode_idx, frame_data) tuples.
    episode2_frames = [frame_data for _, frame_data in flipped_frames]
    flipped_episode_idx = flipped_frames[0][0] if flipped_frames else ep1_idx + 100

    print(f"Episode 1 (original): {len(episode1_frames)} frames")
    print(f"Episode 2 (flipped): {len(episode2_frames)} frames")

    ep1_mid_start, ep1_mid_end = _mid_bounds(len(episode1_frames))
    ep2_mid_start, ep2_mid_end = _mid_bounds(len(episode2_frames))
    print(f"Episode 1 middle section: frames {ep1_mid_start} to {ep1_mid_end}")
    print(f"Episode 2 middle section: frames {ep2_mid_start} to {ep2_mid_end}")

    ep1_mid_states = _mid_states(episode1_frames, ep1_mid_start, ep1_mid_end)
    ep2_mid_states = _mid_states(episode2_frames, ep2_mid_start, ep2_mid_end)
    if not ep1_mid_states or not ep2_mid_states:
        raise ValueError("Couldn't find state vectors in one or both episodes")

    # Pairwise Euclidean distances between all candidate states; the global
    # minimum picks the splice point.
    distances = cdist(np.array([s for _, s in ep1_mid_states]),
                      np.array([s for _, s in ep2_mid_states]),
                      'euclidean')
    min_row, min_col = np.unravel_index(np.argmin(distances), distances.shape)
    best_distance = distances[min_row, min_col]
    ep1_match_idx = ep1_mid_states[min_row][0]
    ep2_match_idx = ep2_mid_states[min_col][0]

    print(f"Best matching states found at:")
    print(f" Episode 1 frame {ep1_match_idx}")
    print(f" Episode 2 frame {ep2_match_idx}")
    print(f" Distance: {best_distance}")

    # Episode 3: head of ep1, tail of ep2. Episode 4: head of ep2, tail of
    # ep1. High output indices avoid colliding with existing episodes.
    episode3_frames = _splice(episode1_frames, ep1_match_idx,
                              episode2_frames, ep2_match_idx,
                              flipped_episode_idx + 100)
    episode4_frames = _splice(episode2_frames, ep2_match_idx,
                              episode1_frames, ep1_match_idx,
                              flipped_episode_idx + 101)

    print(f"Created Episode 3 with {len(episode3_frames)} frames")
    print(f"Created Episode 4 with {len(episode4_frames)} frames")
    return episode3_frames, episode4_frames
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment