Skip to content

Instantly share code, notes, and snippets.

@kenenbek
Created May 30, 2025 12:24
Show Gist options
  • Save kenenbek/dfcddfcf4f9b556bf3f593f2af8e79a4 to your computer and use it in GitHub Desktop.
Save kenenbek/dfcddfcf4f9b556bf3f593f2af8e79a4 to your computer and use it in GitHub Desktop.
import os
import shutil
import tempfile
import unittest

import file_distributor
class TestFileDistributor(unittest.TestCase):
    """Exercise ``file_distributor.distribute_files`` against real directories.

    Every test starts from an empty ``file_distributor.OUTPUT_DIR``; the
    directory is removed once the whole class has run. The tests expect the
    distributor to create one ``worker_<i>`` sub-directory per labeler inside
    the output directory.
    """

    @classmethod
    def tearDownClass(cls):
        """Remove the output directory after the last test of the class."""
        cls._remove_output_dir()

    def setUp(self):
        """Give each test a fresh, empty output directory."""
        self._remove_output_dir()
        os.makedirs(file_distributor.OUTPUT_DIR, exist_ok=True)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _remove_output_dir():
        """Delete ``OUTPUT_DIR`` and everything in it, if it exists."""
        if os.path.exists(file_distributor.OUTPUT_DIR):
            shutil.rmtree(file_distributor.OUTPUT_DIR)

    @staticmethod
    def _distribute(num_labelers, num_copies, source_dir=None):
        """Call ``distribute_files`` targeting ``OUTPUT_DIR``.

        ``source_dir`` defaults to ``file_distributor.SOURCE_DIR``.
        """
        if source_dir is None:
            source_dir = file_distributor.SOURCE_DIR
        file_distributor.distribute_files(
            num_labelers=num_labelers,
            num_copies=num_copies,
            source_dir=source_dir,
            output_dir=file_distributor.OUTPUT_DIR,
        )

    @staticmethod
    def _worker_dir(index):
        """Return the path of the ``worker_<index>`` output directory."""
        return os.path.join(file_distributor.OUTPUT_DIR, f'worker_{index}')

    def _worker_listings(self, num_labelers):
        """Return one directory listing per worker, each read exactly once.

        Raises ``FileNotFoundError`` if a worker directory is missing.
        """
        return [os.listdir(self._worker_dir(i)) for i in range(num_labelers)]

    def _copy_counts(self, file_names, num_labelers):
        """Map each file name to the number of worker directories holding it."""
        listings = [set(entries) for entries in self._worker_listings(num_labelers)]
        return {
            name: sum(name in entries for entries in listings)
            for name in file_names
        }

    def _source_files(self):
        """Return the sorted names of the regular files in ``SOURCE_DIR``.

        Skips the running test when there are none: the per-file assertions
        would otherwise loop zero times and pass without checking anything.
        """
        source_dir = file_distributor.SOURCE_DIR
        names = sorted(
            name for name in os.listdir(source_dir)
            if os.path.isfile(os.path.join(source_dir, name))
        )
        if not names:
            self.skipTest(f"{source_dir} contains no files; nothing to verify")
        return names

    def _make_temp_dir(self, prefix):
        """Create a unique scratch directory that is removed after the test.

        Cleanup is registered with ``addCleanup`` so the directory goes away
        even when the test fails or the distributor raises.
        """
        path = tempfile.mkdtemp(prefix=prefix)
        self.addCleanup(shutil.rmtree, path, ignore_errors=True)
        return path

    # ------------------------------------------------------------------
    # Tests against SOURCE_DIR
    # ------------------------------------------------------------------

    def test_worker_folders_created(self):
        """One ``worker_<i>`` directory exists per labeler."""
        self._distribute(num_labelers=4, num_copies=2)
        for i in range(4):
            self.assertTrue(
                os.path.isdir(self._worker_dir(i)),
                f"worker_{i} directory was not created",
            )

    def test_files_distributed_to_workers(self):
        """Every source file shows up in at least one worker directory."""
        self._distribute(num_labelers=3, num_copies=2)
        distributed = set().union(*self._worker_listings(3))
        for name in self._source_files():
            self.assertIn(name, distributed)

    def test_each_file_copied_to_num_copies_workers(self):
        """Each source file is present in exactly ``num_copies`` workers."""
        num_labelers = 17
        num_copies = 3
        self._distribute(num_labelers=num_labelers, num_copies=num_copies)
        counts = self._copy_counts(self._source_files(), num_labelers)
        for name, count in counts.items():
            self.assertEqual(
                count, num_copies,
                f"File {name} is in {count} workers, expected {num_copies}",
            )

    def test_workers_have_approx_same_amount_of_files(self):
        """Worker directory sizes differ by at most one file."""
        self._distribute(num_labelers=4, num_copies=2)
        sizes = [len(entries) for entries in self._worker_listings(4)]
        self.assertLessEqual(max(sizes) - min(sizes), 1)

    def test_num_labelers_less_than_num_copies(self):
        """With 2 labelers and 3 requested copies, each file is in both workers."""
        self._distribute(num_labelers=2, num_copies=3)
        counts = self._copy_counts(self._source_files(), 2)
        for name, count in counts.items():
            self.assertEqual(
                count, 2,
                f"File {name} should be in 2 workers (max possible), got {count}",
            )

    def test_single_labeler(self):
        """With a single labeler, every source file is in ``worker_0``."""
        self._distribute(num_labelers=1, num_copies=3)
        only_worker = set(os.listdir(self._worker_dir(0)))
        for name in self._source_files():
            self.assertIn(name, only_worker)

    def test_single_copy(self):
        """With one copy requested, each file is in exactly one worker."""
        self._distribute(num_labelers=3, num_copies=1)
        counts = self._copy_counts(self._source_files(), 3)
        for name, count in counts.items():
            self.assertEqual(
                count, 1,
                f"File {name} should be in 1 worker, got {count}",
            )

    def test_no_duplicate_files_in_worker(self):
        """No worker directory lists the same file name twice."""
        self._distribute(num_labelers=4, num_copies=3)
        for i, entries in enumerate(self._worker_listings(4)):
            # NOTE(review): a directory listing cannot contain the same name
            # twice, so this assertion cannot fail as written; detecting
            # double assignment needs a hook into the distributor itself.
            self.assertEqual(
                len(entries), len(set(entries)),
                f"Duplicates found in worker_{i}",
            )

    # ------------------------------------------------------------------
    # Tests against purpose-built source directories
    # ------------------------------------------------------------------

    def test_empty_source_directory(self):
        """An empty source still yields one (empty) directory per labeler."""
        empty_dir = self._make_temp_dir('empty_source_dir')
        self._distribute(num_labelers=3, num_copies=2, source_dir=empty_dir)
        for i in range(3):
            worker_dir = self._worker_dir(i)
            self.assertTrue(os.path.isdir(worker_dir))
            self.assertEqual(len(os.listdir(worker_dir)), 0)

    def test_num_labelers_greater_than_num_files(self):
        """Two files, one copy each, five labelers: exactly two workers get a file."""
        source_dir = self._make_temp_dir('temp_source_dir')
        for name, content in (('a.txt', 'A'), ('b.txt', 'B')):
            with open(os.path.join(source_dir, name), 'w', encoding='utf-8') as handle:
                handle.write(content)
        self._distribute(num_labelers=5, num_copies=1, source_dir=source_dir)
        non_empty = sum(1 for entries in self._worker_listings(5) if entries)
        self.assertEqual(non_empty, 2)

    def test_nonexistent_source_directory(self):
        """A missing source directory raises ``FileNotFoundError``."""
        # Build the path inside a fresh scratch directory so it is guaranteed
        # not to exist, whatever the current working directory contains.
        missing_dir = os.path.join(
            self._make_temp_dir('missing_parent'), 'nonexistent_dir'
        )
        with self.assertRaises(FileNotFoundError):
            self._distribute(num_labelers=3, num_copies=2, source_dir=missing_dir)

    def test_output_folders_created_even_if_no_files(self):
        """Worker directories are created even when there is nothing to copy."""
        empty_dir = self._make_temp_dir('empty_source_dir2')
        self._distribute(num_labelers=4, num_copies=2, source_dir=empty_dir)
        for i in range(4):
            self.assertTrue(os.path.isdir(self._worker_dir(i)))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment