Last active
October 11, 2023 03:16
-
-
Save dsevero/b4c945f0b4a020c6f876920e0e74b608 to your computer and use it in GitHub Desktop.
Helper script to patchify images in a directory
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Usage: python patchify_padded.py 'my_dir/*.png' output_dir 32
This script:
1) Loads all images matching the glob 'my_dir/*.png'
2) Pads them with zeros (on the bottom and right edges) such that the height and width are divisible by the specified patch size (in this case, 32)
3) Splits each padded image into non-overlapping square patches of that size and saves the patches to disk in parallel, under output_dir/<image name without extension>/<patch size>/patch_<index>.png
'''
from PIL import Image | |
from pathlib import Path | |
from multiprocessing import Pool | |
import numpy as np | |
import os | |
import argparse | |
import glob | |
# Function to save a single patch | |
def save_patch(args):
    """Write one patch to disk as an image file.

    ``args`` is a single ``(patch, patch_filename)`` pair, packed into one
    argument. The patch array is converted with ``Image.fromarray`` and
    saved to ``patch_filename``.
    """
    pixels, destination = args
    Image.fromarray(pixels).save(destination)
def pad_to_multiple(image_array, patch_size):
    """Zero-pad an image array so its height and width are multiples of patch_size.

    Padding is appended to the bottom and right edges only. Only the first
    two axes are padded, so both 2-D (grayscale) arrays of shape (H, W) and
    arrays with trailing axes such as (H, W, C) are supported.

    Raises:
        ValueError: if patch_size is not positive or the array has fewer
            than two dimensions.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")
    if image_array.ndim < 2:
        raise ValueError(f"expected an array with at least 2 dimensions, got {image_array.ndim}")
    height, width = image_array.shape[:2]
    # (-n) % k is the smallest non-negative amount that makes n divisible by k.
    pad_height = (-height) % patch_size
    pad_width = (-width) % patch_size
    # Leave any trailing axes (e.g. colour channels) untouched.
    pad_spec = [(0, pad_height), (0, pad_width)] + [(0, 0)] * (image_array.ndim - 2)
    return np.pad(image_array, pad_spec, mode='constant')


def iter_patches(image_array, patch_size):
    """Yield non-overlapping patch_size x patch_size tiles in row-major order.

    The array's height and width are expected to already be multiples of
    patch_size (see pad_to_multiple); each yielded patch is a slice of the
    input array, not a copy.
    """
    height, width = image_array.shape[:2]
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            yield image_array[top:top + patch_size, left:left + patch_size]


if __name__ == '__main__':
    # Create a command-line argument parser
    parser = argparse.ArgumentParser(description="Partition an image into patches and save them.")
    parser.add_argument("image_glob", help="Glob pattern to match input image files")
    parser.add_argument("output_directory", help="Directory to save the patches")
    parser.add_argument("patch_size", help="Size of patches", type=int)
    args = parser.parse_args()

    # Reject sizes that would divide by zero or produce nonsensical padding.
    if args.patch_size <= 0:
        parser.error("patch_size must be a positive integer")

    # Collect the input image paths
    image_paths = glob.glob(args.image_glob)

    # Define the size of the patches
    patch_size = args.patch_size

    # Create the output directory if it doesn't exist
    output_dir = Path(args.output_directory)
    os.makedirs(output_dir, exist_ok=True)

    # The context manager guarantees the workers are cleaned up even if an
    # error is raised while images are being processed.
    with Pool() as pool:
        pending = []
        for path in image_paths:
            print(f"Patchifying {path} ", end='')

            # Convert the PIL image to a NumPy array, closing the file afterwards
            with Image.open(path) as image:
                image_array = np.array(image)

            # Only the first two axes are spatial; grayscale images have no channel axis
            height, width = image_array.shape[:2]
            print(f'... Original shape: ({height}, {width})', end='')

            # Pad so that both dimensions are divisible by patch_size
            image_array = pad_to_multiple(image_array, patch_size)
            new_height, new_width = image_array.shape[:2]
            print(f' ... Padded shape: ({new_height}, {new_width})', end='')

            # Create the output dir for this image
            image_patches_dir = output_dir / Path(path).stem / str(patch_size)
            os.makedirs(image_patches_dir, exist_ok=True)

            counter = 0
            for patch in iter_patches(image_array, patch_size):
                # Define a filename for the patch
                patch_filename = image_patches_dir / f"patch_{counter}.png"
                # Queue the save on the pool; keep the handle so failures can be surfaced
                pending.append(pool.apply_async(save_patch, args=((patch, patch_filename),)))
                counter += 1
            print(f' ... Created {counter} patches')

        # No more work will be submitted; wait for every save to finish.
        pool.close()
        for result in pending:
            # get() re-raises any exception that occurred in the worker, so a
            # failed save is reported instead of being silently dropped.
            result.get()
        pool.join()

    print("Patches saved successfully.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment