Skip to content

Instantly share code, notes, and snippets.

@macleginn
Created June 2, 2026 11:28
Show Gist options
  • Select an option

  • Save macleginn/6168f47af16959c3037b1663f4005ba2 to your computer and use it in GitHub Desktop.

Select an option

Save macleginn/6168f47af16959c3037b1663f4005ba2 to your computer and use it in GitHub Desktop.
A script for applying reservoir sampling to the Dolma dataset
import random
import gzip
import json
import os
import requests
import pickle
from typing import List, Dict, Any, Tuple
from datetime import datetime
def _download_to_file(url: str, destination: str, timeout: float) -> None:
    """Stream the body of ``url`` into ``destination`` in 8 KiB chunks."""
    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


def _write_checkpoint_atomically(checkpoint_file: str, checkpoint_data: Dict[str, Any]) -> None:
    """Pickle ``checkpoint_data`` to a sibling temp file, then swap it into place."""
    temp_checkpoint = checkpoint_file + '.tmp'
    with open(temp_checkpoint, 'wb') as f:
        pickle.dump(checkpoint_data, f)
    # os.replace swaps the file in one step, so an interruption never leaves
    # a half-written checkpoint under the real name.
    os.replace(temp_checkpoint, checkpoint_file)


def reservoir_sampling_with_checkpoint(
    url_list: List[str],
    sample_size: int = 100000,
    random_seed: int = 42,
    checkpoint_file: str = "reservoir_checkpoint.pkl",
    request_timeout: float = 60.0,
) -> List[Dict[str, Any]]:
    """
    Perform reservoir sampling with a single checkpoint that gets updated after each file.

    Each URL is downloaded to a temporary file in the working directory, read
    line by line as gzipped JSONL, and deleted again. Lines are kept as raw
    strings while sampling and are only parsed as JSON once all files are done.

    A file whose download or processing fails is reported and skipped; it is
    still recorded as processed in the checkpoint, so it is not retried on
    resume. Lines read from it before the failure remain counted.

    The sampler uses its own random generator, so the global ``random`` module
    state is left untouched. The generator state is stored in the checkpoint,
    which lets a resumed run continue the same random stream instead of
    restarting it from the seed.

    Args:
        url_list: List of URLs to download gzipped jsonl files from
        sample_size: Number of items to sample (default: 100000)
        random_seed: Random seed for reproducibility (default: 42). Ignored on
            resume when the checkpoint carries a saved generator state.
        checkpoint_file: Path to the checkpoint file (default: "reservoir_checkpoint.pkl")
        request_timeout: Timeout in seconds passed to ``requests.get`` so a
            stalled server cannot hang the run forever (default: 60.0)

    Returns:
        List of sampled items (lines that fail to parse as JSON are dropped)
    """
    # A private generator yields the same stream as random.seed(random_seed)
    # would, without clobbering the caller's global RNG state.
    rng = random.Random(random_seed)

    # Initialize or load from checkpoint
    if os.path.exists(checkpoint_file):
        print(f"Loading checkpoint from: {checkpoint_file}")
        # SECURITY: pickle.load can execute arbitrary code. Only resume from
        # checkpoint files that this script wrote itself.
        with open(checkpoint_file, 'rb') as f:
            checkpoint_data = pickle.load(f)
        reservoir_raw = checkpoint_data['reservoir']
        items_seen = checkpoint_data['items_seen']
        last_file_idx = checkpoint_data['last_file_idx']
        start_idx = last_file_idx + 1

        # Checkpoints written by older versions have no RNG state; those fall
        # back to a generator freshly seeded with random_seed.
        rng_state = checkpoint_data.get('rng_state')
        if rng_state is not None:
            rng.setstate(rng_state)

        saved_sample_size = checkpoint_data.get('sample_size')
        if saved_sample_size is not None and saved_sample_size != sample_size:
            print(
                f"Warning: checkpoint was created with sample_size={saved_sample_size}, "
                f"but sample_size={sample_size} was requested; the sample may be biased"
            )

        print(f"Resumed: {len(reservoir_raw)} items in reservoir, {items_seen} total items seen")
        print(f"Last processed file: {url_list[last_file_idx] if 0 <= last_file_idx < len(url_list) else 'unknown'}")
        print(f"Starting from file index {start_idx}")
    else:
        # Start fresh
        reservoir_raw = []
        items_seen = 0
        start_idx = 0
        print("Starting fresh sampling")

    # Process files
    for url_idx in range(start_idx, len(url_list)):
        url = url_list[url_idx]
        print(f"\nProcessing file {url_idx + 1}/{len(url_list)}: {url}")
        temp_filename = f"temp_file_{url_idx}.jsonl.gz"
        file_items = 0
        try:
            # Download file
            print(" Downloading...")
            _download_to_file(url, temp_filename, request_timeout)

            # Process the gzipped jsonl file
            print(" Processing...")
            with gzip.open(temp_filename, 'rt', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:  # Skip empty lines
                        continue
                    items_seen += 1
                    file_items += 1
                    # Reservoir sampling algorithm
                    if len(reservoir_raw) < sample_size:
                        # Reservoir not full yet, add raw line
                        reservoir_raw.append(line)
                    else:
                        # Keep the new item with probability sample_size / items_seen,
                        # overwriting a uniformly chosen slot.
                        j = rng.randint(0, items_seen - 1)
                        if j < sample_size:
                            reservoir_raw[j] = line
                    if file_items % 100000 == 0:
                        print(f" Processed {file_items} items from this file ({items_seen} total)")
            print(f" Finished processing {file_items} items from this file")
        except requests.RequestException as e:
            print(f" Error downloading {url}: {e}")
        except Exception as e:
            # Deliberate best effort: one bad file must not abort the whole run.
            print(f" Error processing {url}: {e}")
        finally:
            # Clean up: delete the temporary file
            if os.path.exists(temp_filename):
                os.remove(temp_filename)
                print(" Cleaned up temporary file")

        # Save/update checkpoint after each file
        checkpoint_data = {
            'reservoir': reservoir_raw,
            'items_seen': items_seen,
            'last_file_idx': url_idx,
            'sample_size': sample_size,
            'random_seed': random_seed,
            'rng_state': rng.getstate(),
            'timestamp': datetime.now().isoformat(),
            'last_url': url,
            'files_processed': url_idx + 1,
            'total_files': len(url_list),
        }
        _write_checkpoint_atomically(checkpoint_file, checkpoint_data)
        print(f" Updated checkpoint ({url_idx + 1}/{len(url_list)} files processed)")
        print(f" Current reservoir size: {len(reservoir_raw)}")
        print(f" Total items seen: {items_seen}")

    # Parse all selected lines at the end
    print("\nParsing selected items...")
    reservoir = []
    parse_errors = 0
    for idx, raw_line in enumerate(reservoir_raw):
        try:
            reservoir.append(json.loads(raw_line))
        except json.JSONDecodeError as e:
            parse_errors += 1
            print(f"Warning: Failed to parse line {idx}: {e}")
        if (idx + 1) % 10000 == 0:
            print(f" Parsed {idx + 1}/{len(reservoir_raw)} items")

    print(f"\nSampling complete. Final sample size: {len(reservoir)}")
    if parse_errors > 0:
        print(f" JSON parse errors: {parse_errors}")
    # The checkpoint file is left in place; the caller decides when to delete it.
    return reservoir
if __name__ == "__main__":
    # Single source of truth for the checkpoint path used below.
    checkpoint_path = "my_sampling_checkpoint.pkl"

    # Read the URL list, dropping "starcoder" entries. Blank lines are skipped
    # too; otherwise they would become empty-string URLs that fail on every run.
    with open("v1_7.txt", "r") as inp:
        url_list = [
            url
            for url in (line.strip() for line in inp)
            if url and "starcoder" not in url
        ]

    # Run sampling with single checkpoint
    sample = reservoir_sampling_with_checkpoint(
        url_list,
        sample_size=200000,
        random_seed=43,  # Second iteration
        checkpoint_file=checkpoint_path,
    )

    # Save the final sample, one JSON document per line
    with open("sample.jsonl", "w", encoding="utf-8") as f:
        for item in sample:
            f.write(json.dumps(item) + "\n")
    print("\nSample saved to sample.jsonl")

    # The sample is safely on disk, so the checkpoint is no longer needed.
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
        print("Removed checkpoint file")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment