Created
June 2, 2026 11:28
-
-
Save macleginn/6168f47af16959c3037b1663f4005ba2 to your computer and use it in GitHub Desktop.
A script for applying reservoir sampling to the Dolma dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| import gzip | |
| import json | |
| import os | |
| import requests | |
| import pickle | |
| from typing import List, Dict, Any, Tuple | |
| from datetime import datetime | |
def _load_checkpoint(
    checkpoint_file: str,
    url_list: List[str],
    sample_size: int,
    random_seed: int,
    rng: random.Random,
) -> Tuple[List[str], int, int, List[str]]:
    """Load resumable state, or return a fresh state if no checkpoint exists.

    Restores ``rng`` in place from the checkpoint's saved RNG state when one is
    present.

    Returns:
        ``(reservoir_raw, items_seen, start_idx, failed_urls)``.

    Raises:
        ValueError: If the checkpoint was created with a different
            ``sample_size`` or ``random_seed`` (resuming would mix two runs).
    """
    if not os.path.exists(checkpoint_file):
        print("Starting fresh sampling")
        return [], 0, 0, []

    print(f"Loading checkpoint from: {checkpoint_file}")
    # NOTE: pickle.load can execute arbitrary code; only resume from
    # checkpoint files this script wrote itself.
    with open(checkpoint_file, 'rb') as f:
        checkpoint_data = pickle.load(f)

    saved_size = checkpoint_data.get('sample_size', sample_size)
    saved_seed = checkpoint_data.get('random_seed', random_seed)
    if saved_size != sample_size or saved_seed != random_seed:
        raise ValueError(
            f"Checkpoint {checkpoint_file!r} was created with sample_size={saved_size}, "
            f"random_seed={saved_seed}, but this run requested sample_size={sample_size}, "
            f"random_seed={random_seed}. Delete the checkpoint or use another checkpoint_file."
        )
    saved_total = checkpoint_data.get('total_files')
    if saved_total is not None and saved_total != len(url_list):
        print(f"Warning: checkpoint was made for {saved_total} files, "
              f"but url_list has {len(url_list)}")

    reservoir_raw = checkpoint_data['reservoir']
    items_seen = checkpoint_data['items_seen']
    last_file_idx = checkpoint_data['last_file_idx']
    failed_urls = list(checkpoint_data.get('failed_urls', []))

    rng_state = checkpoint_data.get('rng_state')
    if rng_state is not None:
        # Continue the exact random sequence of the interrupted run.
        rng.setstate(rng_state)
    else:
        print("Warning: checkpoint has no RNG state (older format); "
              "the random sequence restarts from the seed")

    start_idx = last_file_idx + 1
    print(f"Resumed: {len(reservoir_raw)} items in reservoir, {items_seen} total items seen")
    print(f"Last processed file: "
          f"{url_list[last_file_idx] if 0 <= last_file_idx < len(url_list) else 'unknown'}")
    print(f"Starting from file index {start_idx}")
    return reservoir_raw, items_seen, start_idx, failed_urls


def _save_checkpoint(checkpoint_file: str, checkpoint_data: Dict[str, Any]) -> None:
    """Write the checkpoint to a temp file, then atomically rename it into place."""
    temp_checkpoint = checkpoint_file + '.tmp'
    with open(temp_checkpoint, 'wb') as f:
        pickle.dump(checkpoint_data, f)
        # Make sure the bytes are on disk before the rename publishes them.
        f.flush()
        os.fsync(f.fileno())
    # Atomic rename to avoid a corrupt checkpoint if interrupted.
    os.replace(temp_checkpoint, checkpoint_file)


def _download_file(url: str, destination: str, timeout: float) -> None:
    """Stream ``url`` into ``destination``; raise on HTTP errors or timeouts."""
    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


def _sample_lines(
    lines: Any,
    reservoir_raw: List[str],
    items_seen: int,
    sample_size: int,
    rng: random.Random,
) -> Tuple[int, int]:
    """Feed non-empty stripped lines from ``lines`` into the reservoir (Algorithm R).

    Mutates ``reservoir_raw`` in place.

    Returns:
        ``(items_seen, file_items)``: the updated global item count and the
        number of items taken from this iterable.
    """
    file_items = 0
    for raw in lines:
        line = raw.strip()
        if not line:  # Skip empty lines
            continue
        items_seen += 1
        file_items += 1
        if len(reservoir_raw) < sample_size:
            # Reservoir not full yet, add raw line
            reservoir_raw.append(line)
        else:
            # Keep the items_seen-th item with probability sample_size / items_seen.
            j = rng.randint(0, items_seen - 1)
            if j < sample_size:
                reservoir_raw[j] = line
        if file_items % 100000 == 0:
            print(f"  Processed {file_items} items from this file ({items_seen} total)")
    return items_seen, file_items


def _parse_reservoir(reservoir_raw: List[str]) -> Tuple[List[Dict[str, Any]], int]:
    """Parse each raw JSON line; return ``(parsed_items, parse_error_count)``."""
    print("\nParsing selected items...")
    reservoir: List[Dict[str, Any]] = []
    parse_errors = 0
    for idx, raw_line in enumerate(reservoir_raw):
        try:
            reservoir.append(json.loads(raw_line))
        except json.JSONDecodeError as e:
            parse_errors += 1
            print(f"Warning: Failed to parse line {idx}: {e}")
        if (idx + 1) % 10000 == 0:
            print(f"  Parsed {idx + 1}/{len(reservoir_raw)} items")
    return reservoir, parse_errors


def reservoir_sampling_with_checkpoint(
    url_list: List[str],
    sample_size: int = 100000,
    random_seed: int = 42,
    checkpoint_file: str = "reservoir_checkpoint.pkl",
    request_timeout: float = 60.0,
) -> List[Dict[str, Any]]:
    """
    Perform reservoir sampling with a single checkpoint that gets updated after each file.

    Every non-empty line of every gzipped JSONL file is a candidate item. The
    reservoir keeps a uniform random sample of ``sample_size`` raw lines
    (Algorithm R), and these are parsed as JSON only at the end.

    The checkpoint stores the reservoir, the item counter, the index of the
    last processed file and the sampler's RNG state. A run resumed from a
    checkpoint therefore draws the same random numbers an uninterrupted run
    would have drawn.

    When a file fails to download or decode, its partial contribution is
    rolled back. The file is then recorded under ``'failed_urls'`` in the
    checkpoint and skipped, so the sample is uniform over the files that were
    processed successfully.

    The checkpoint file is left in place after completion.

    Args:
        url_list: List of URLs to download gzipped jsonl files from
        sample_size: Number of items to sample (default: 100000)
        random_seed: Seed for the sampler's private RNG; the global ``random``
            state is not touched (default: 42)
        checkpoint_file: Path to the checkpoint pickle (default:
            "reservoir_checkpoint.pkl"). Only resume from checkpoints you
            created: unpickling untrusted data can execute arbitrary code.
        request_timeout: Seconds allowed for connecting and for each read
            before a download is abandoned (default: 60)

    Returns:
        List of sampled items (parsed JSON values). Lines that are not valid
        JSON are reported and dropped, so the list can be shorter than
        ``sample_size``.

    Raises:
        ValueError: If ``sample_size`` is negative, or if an existing checkpoint
            was created with a different ``sample_size`` or ``random_seed``.
    """
    if sample_size < 0:
        raise ValueError(f"sample_size must be non-negative, got {sample_size}")

    # A private generator keeps the sampler reproducible; its state is
    # checkpointed. Random(seed) yields the same sequence as random.seed(seed).
    rng = random.Random(random_seed)
    reservoir_raw, items_seen, start_idx, failed_urls = _load_checkpoint(
        checkpoint_file, url_list, sample_size, random_seed, rng
    )

    for url_idx in range(start_idx, len(url_list)):
        url = url_list[url_idx]
        print(f"\nProcessing file {url_idx + 1}/{len(url_list)}: {url}")
        temp_filename = f"temp_file_{url_idx}.jsonl.gz"

        # Snapshot everything this file can mutate, so that a failure part-way
        # through leaves no trace of the file in the sample.
        reservoir_before = list(reservoir_raw)
        items_seen_before = items_seen
        rng_state_before = rng.getstate()
        failed = False
        try:
            print("  Downloading...")
            _download_file(url, temp_filename, request_timeout)
            print("  Processing...")
            with gzip.open(temp_filename, 'rt', encoding='utf-8') as f:
                items_seen, file_items = _sample_lines(
                    f, reservoir_raw, items_seen, sample_size, rng
                )
            print(f"  Finished processing {file_items} items from this file")
        except requests.RequestException as e:
            print(f"  Error downloading {url}: {e}")
            failed = True
        except Exception as e:  # best-effort: log, roll back, move on
            print(f"  Error processing {url}: {e}")
            failed = True
        finally:
            # Clean up: delete the temporary file
            if os.path.exists(temp_filename):
                os.remove(temp_filename)
                print("  Cleaned up temporary file")

        if failed:
            reservoir_raw = reservoir_before
            items_seen = items_seen_before
            rng.setstate(rng_state_before)
            failed_urls.append(url)
            print("  Rolled back partial results; file recorded in 'failed_urls'")

        _save_checkpoint(checkpoint_file, {
            'reservoir': reservoir_raw,
            'items_seen': items_seen,
            'last_file_idx': url_idx,
            'sample_size': sample_size,
            'random_seed': random_seed,
            'rng_state': rng.getstate(),
            'failed_urls': failed_urls,
            'timestamp': datetime.now().isoformat(),
            'last_url': url,
            'files_processed': url_idx + 1,
            'total_files': len(url_list),
        })
        print(f"  Updated checkpoint ({url_idx + 1}/{len(url_list)} files processed)")
        print(f"  Current reservoir size: {len(reservoir_raw)}")
        print(f"  Total items seen: {items_seen}")

    reservoir, parse_errors = _parse_reservoir(reservoir_raw)
    print(f"\nSampling complete. Final sample size: {len(reservoir)}")
    if parse_errors > 0:
        print(f"  JSON parse errors: {parse_errors}")
    if failed_urls:
        print(f"  Files skipped after errors ({len(failed_urls)}):")
        for failed_url in failed_urls:
            print(f"    {failed_url}")
    return reservoir
if __name__ == "__main__":
    # Used both for resuming and for the final cleanup, so define it once.
    CHECKPOINT_FILE = "my_sampling_checkpoint.pkl"

    # One URL per line. Skip blank lines, which would otherwise become
    # empty-URL requests that fail and get logged as errors. Skip StarCoder
    # shards.
    with open("v1_7.txt", "r", encoding="utf-8") as inp:
        url_list = [
            url
            for url in (line.strip() for line in inp)
            if url and "starcoder" not in url
        ]

    # Run sampling with single checkpoint
    sample = reservoir_sampling_with_checkpoint(
        url_list,
        sample_size=200000,
        random_seed=43,  # Second iteration: differs from the first run's seed
        checkpoint_file=CHECKPOINT_FILE,
    )

    # Save the final sample. Default ensure_ascii=True keeps lone surrogates
    # escaped, so writing as UTF-8 cannot fail.
    with open("sample.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in sample)
    print(f"\nSample saved to sample.jsonl")

    # The sample is safely on disk, so the checkpoint is no longer needed.
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
        print("Removed checkpoint file")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment