Last active
January 3, 2023 21:25
-
-
Save maxim04/0ccc2d69489b4cb9717f18dbf6e615cf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#@markdown Run to generate a grid of preview images from the last saved weights.
# Builds one row per numbered folder under $OUTPUT_DIR and one column per image
# in that folder's "samples" subfolder, then saves the grid to grid.png in the
# current working directory.
import os

# Only files with these extensions are treated as sample images; any other file
# in a samples folder (text files, etc.) would otherwise be passed to imread.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")


def list_checkpoint_folders(weights_folder, skip=("0",)):
    """Return the numbered sub-folders of *weights_folder*, in numeric order.

    An entry is kept when its name is a decimal number (e.g. "800"), it is not
    listed in *skip*, and it contains a "samples" directory. Anything else
    (config files, log folders, ...) is ignored instead of crashing the
    ``int`` sort key.

    Args:
        weights_folder: Directory holding the numbered output folders.
        skip: Folder names to leave out; by default folder "0" is skipped.

    Returns:
        List of folder names sorted by their integer value.
    """
    names = [
        name
        for name in os.listdir(weights_folder)
        if name.isdecimal()
        and name not in skip
        and os.path.isdir(os.path.join(weights_folder, name, "samples"))
    ]
    return sorted(names, key=int)


def _natural_key(filename):
    """Sort key putting numeric stems first in numeric order ("2.png" < "10.png")."""
    stem = os.path.splitext(filename)[0]
    if stem.isdecimal():
        return (0, int(stem), filename)
    return (1, 0, filename)


def list_sample_images(image_folder):
    """Return the image file names in *image_folder* in a stable, natural order.

    ``os.listdir`` order is arbitrary; sorting makes column ``j`` refer to the
    same file name in every row of the grid.
    """
    return sorted(
        (
            name
            for name in os.listdir(image_folder)
            if name.lower().endswith(IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(image_folder, name))
        ),
        key=_natural_key,
    )


# Guarded so the helpers above can be imported without plotting. Inside a
# Colab/Jupyter cell __name__ is "__main__", so the cell still runs as before
# and the names below (plt, fig, axes, ...) remain available as globals.
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg

    weights_folder = os.getenv('OUTPUT_DIR')
    if not weights_folder:
        raise RuntimeError("OUTPUT_DIR is not set; point it at the training output folder.")
    folders = list_checkpoint_folders(weights_folder)
    if not folders:
        raise RuntimeError(
            f"No numbered folders with a 'samples' subfolder found in {weights_folder!r}."
        )

    # Gather every row's images first so the column count fits the largest row.
    image_lists = [
        list_sample_images(os.path.join(weights_folder, folder, "samples"))
        for folder in folders
    ]
    row = len(folders)
    col = max(len(images) for images in image_lists)
    if col == 0:
        raise RuntimeError(f"No sample images found under {weights_folder!r}.")

    scale = 4
    # squeeze=False keeps `axes` 2-D even when row == 1 or col == 1, so
    # axes[i, j] is always valid.
    fig, axes = plt.subplots(
        row,
        col,
        figsize=(col * scale, row * scale),
        gridspec_kw={'hspace': 0, 'wspace': 0},
        squeeze=False,
    )
    for i, (folder, images) in enumerate(zip(folders, image_lists)):
        image_folder = os.path.join(weights_folder, folder, "samples")
        for j in range(col):
            ax = axes[i, j]
            if i == 0:
                ax.set_title(f"Image {j}")
            if j == 0:
                # Row label: the folder name, drawn just left of the first cell.
                ax.text(-0.1, 0.5, folder, rotation=0, va='center', ha='center',
                        transform=ax.transAxes)
            # Rows with fewer images than `col` leave their trailing cells blank.
            if j < len(images):
                img = mpimg.imread(os.path.join(image_folder, images[j]))
                ax.imshow(img, cmap='gray')
            ax.axis('off')
    plt.tight_layout()
    plt.savefig('grid.png', dpi=72)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment