Last active
August 24, 2024 10:55
-
-
Save Ammar-Ishfaq/4fe597246a4fb9a766c72b07da257e60 to your computer and use it in GitHub Desktop.
train_val_test_split.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Python script to split a labeled image dataset into Train, Validation, and Test folders.
# Author: Evan Juras, EJ Technology Consultants
# Date: 4/10/21
# Randomly splits images into 80% train, 10% validation, and 10% test, and moves them (together with their .xml annotation files) to their respective folders.
# This script is intended to be used in the TFLite Object Detection Colab notebook here:
# https://colab.research.google.com/github/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi/blob/master/Train_TFLite2_Object_Detction_Model.ipynb
from pathlib import Path
import random
import os
import sys

# Define paths to image folders
image_path = '/kaggle/working/images/all'
train_path = '/kaggle/working/images/train'
val_path = '/kaggle/working/images/validation'
test_path = '/kaggle/working/images/test'

# Fraction of the images sent to each split. The test split receives whatever
# is left after train and validation are filled, so test_percent is
# informational only; it is kept for readers who edit these values.
train_percent = 0.8  # 80% of the files go to train
val_percent = 0.1  # 10% go to validation
test_percent = 0.1  # 10% go to test

# Image extensions to collect. Matching is case-insensitive on every platform,
# so '.JPG', '.PNG', '.Bmp', ... are found on Linux, macOS and Windows alike.
IMAGE_EXTENSIONS = frozenset({'.jpeg', '.jpg', '.png', '.bmp'})


def find_images(root, extensions=IMAGE_EXTENSIONS):
    """Return every image file below *root* (searched recursively), sorted.

    A regular file counts as an image when its suffix, compared
    case-insensitively, is one of *extensions* (given with the leading dot).
    Sorting makes the listing independent of filesystem iteration order, so a
    seeded split is reproducible.
    """
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        path for path in Path(root).rglob('*')
        if path.suffix.lower() in wanted and path.is_file()
    )


def split_counts(file_num, train_frac, val_frac):
    """Return ``(train_num, val_num, test_num)`` for *file_num* images.

    The train and validation sizes are rounded down with ``int()``; the test
    split takes the remainder, so the three counts always sum to *file_num*.
    """
    train_num = int(file_num * train_frac)
    val_num = int(file_num * val_frac)
    return train_num, val_num, file_num - train_num - val_num


def split_dataset(files, train_num, val_num, seed=None):
    """Randomly partition *files* into ``(train, val, test)`` lists.

    A copy of *files* is shuffled once and sliced (O(n)), rather than drawing
    and removing one element at a time (O(n^2)). *files* is not modified.
    With *seed* left as None the module-level ``random`` generator is used, so
    a prior ``random.seed()`` call still controls the split; pass an int for a
    self-contained reproducible split.
    """
    shuffled = list(files)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)
    val_end = train_num + val_num
    return shuffled[:train_num], shuffled[train_num:val_end], shuffled[val_end:]


def move_with_annotation(image, dest_dir):
    """Move *image* and its same-stem ``.xml`` annotation into *dest_dir*.

    Returns True when the annotation was moved as well. If the annotation is
    missing, the image is still moved, a warning is printed and False is
    returned, instead of aborting halfway through the split with the dataset
    partially moved. Moves use Path.rename, so *dest_dir* must already exist
    and be on the same filesystem as *image*.
    """
    image = Path(image)
    dest_dir = Path(dest_dir)
    # Same rule as "<stem>.xml" in the image's folder: only the last suffix
    # is replaced (e.g. 'a.b.jpg' -> 'a.b.xml').
    annotation = image.with_suffix('.xml')
    image.rename(dest_dir / image.name)
    if not annotation.exists():
        print('Warning: no annotation file found for %s' % image)
        return False
    annotation.rename(dest_dir / annotation.name)
    return True


def main():
    """Split the images in image_path into the train/validation/test folders."""
    file_list = find_images(image_path)
    file_num = len(file_list)
    print('Total images: %d' % file_num)

    train_num, val_num, test_num = split_counts(file_num, train_percent, val_percent)
    print('Images moving to train: %d' % train_num)
    print('Images moving to validation: %d' % val_num)
    print('Images moving to test: %d' % test_num)

    splits = split_dataset(file_list, train_num, val_num)
    for dest, files in zip((train_path, val_path, test_path), splits):
        # rename cannot create missing folders, so make sure each destination
        # exists before moving anything into it.
        Path(dest).mkdir(parents=True, exist_ok=True)
        for image in files:
            move_with_annotation(image, dest)


# Guarded so the helpers can be imported without moving any files.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment