Last active
October 13, 2023 19:26
-
-
Save dsevero/6140d918c0dfcdf6af3dca3cc8a261b2 to your computer and use it in GitHub Desktop.
Helper scripts to compute bits-per-dimension of images in a directory.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Helper script to compute BPD of images in a directory. | |
The script will never modify any files. | |
All conversions are done in memory on a copy. | |
Usage: python compute_bpd.py your_glob_pattern [extension] [colorspace] [psnr-check] [per-channel]
Args:
- extension: any valid PIL image extension (e.g., PNG, WebP, JPEG).
- colorspace: any valid PIL colorspace plus the lossless YCoCg.
- psnr-check (flag): if set, will compute PSNR of saved image with respect to the original colorspace.
- per-channel (flag): if set, will compute BPD separately for each channel.
Flags are only recognized after the colorspace argument.
Examples: | |
# Compute BPD of all images in the directory | |
python compute_bpd.py your_glob | |
# Convert to PNG and compute bpd | |
python compute_bpd.py your_glob png | |
# Convert to webp and compute bpd | |
python compute_bpd.py your_glob webp | |
# Convert to lossless (with respect to RGB) YCoCg and save as PNG | |
python compute_bpd.py your_glob png YCoCg | |
# Convert to lossless (with respect to RGB) YCoCg, save as PNG, and check PSNR | |
python compute_bpd.py your_glob png YCoCg psnr-check | |
""" | |
import glob | |
import os | |
import sys | |
import math | |
import numpy as np | |
import io | |
from pathlib import Path | |
from multiprocessing import Pool | |
from PIL import Image | |
from functools import partial | |
from typing import Optional | |
# Lower-case file suffixes (including the leading dot) that the script treats
# as images. Files with any other extension are ignored.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".bmp"]
def compute_psnr(img1, img2, max_value=255):
    """Return the peak signal-to-noise ratio (in dB) between two arrays.

    Args:
        img1: first array (anything ``np.asarray`` accepts).
        img2: second array, broadcast-compatible with ``img1``.
        max_value: peak sample value; 255 corresponds to 8-bit samples.

    Returns:
        ``20 * log10(max_value / sqrt(MSE))``, or ``inf`` when the mean
        squared error is exactly zero.
    """
    # Promote to float64 *before* subtracting. Arrays produced by
    # ``np.array(pil_image)`` are typically uint8, and uint8 arithmetic wraps
    # modulo 256: both the difference and its square would silently
    # overflow, yielding a far too small MSE for differences >= 16.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")
    return 20 * math.log10(max_value / math.sqrt(mse))
def YCoCg_from_RGB(img_array):
    """Transform an RGB array into (Y, Co, Cg) with two integer lifting steps.

    Args:
        img_array: integer array whose last axis holds R, G, B at indices
            0, 1, 2.

    Returns:
        Array with Y, Co, Cg stacked along a new last axis. The arithmetic is
        carried out in the input's dtype, so unsigned 8-bit input wraps
        modulo 256; ``RGB_from_YCoCg`` undoes the same steps in reverse.
    """
    red = img_array[..., 0]
    green = img_array[..., 1]
    blue = img_array[..., 2]

    # First lifting step: blue/red difference and their (floored) midpoint.
    co = blue - red
    red_blue_mid = red + (co >> 1)

    # Second lifting step: midpoint/green difference and the luma average.
    cg = red_blue_mid - green
    y = green + (cg >> 1)

    return np.stack((y, co, cg), axis=-1)
def RGB_from_YCoCg(img_array):
    """Invert ``YCoCg_from_RGB``: rebuild (R, G, B) from (Y, Co, Cg).

    Args:
        img_array: integer array whose last axis holds Y, Co, Cg at indices
            0, 1, 2.

    Returns:
        Array with R, G, B stacked along a new last axis, computed in the
        input's dtype (unsigned 8-bit input wraps modulo 256).
    """
    y, co, cg = (img_array[..., idx] for idx in range(3))

    # Undo the second lifting step: recover green and the red/blue midpoint.
    green = y - (cg >> 1)
    red_blue_mid = green + cg

    # Undo the first lifting step: recover red, then blue.
    red = red_blue_mid - (co >> 1)
    blue = red + co

    return np.stack((red, green, blue), axis=-1)
def is_image_file(filename):
    """Return True if ``filename`` has a suffix listed in IMAGE_EXTENSIONS.

    Args:
        filename: a ``pathlib.Path`` (any object exposing ``.suffix``).

    The suffix is lower-cased before the lookup so that files such as
    ``photo.PNG`` or ``scan.JPG`` are not silently skipped; this matches the
    case-insensitive suffix comparison done in ``compute_bpd``.
    """
    return filename.suffix.lower() in IMAGE_EXTENSIONS
def compute_bpd(
    file_path: Path,
    extension: Optional[str],
    colorspace: Optional[str],
    psnr_check: bool,
    per_channel: bool,
):
    """Compute bits-per-dimension for one image, re-encoding only if needed.

    Args:
        file_path: path of the image file.
        extension: target format/extension, or None to keep the file as is.
        colorspace: target colorspace, or None to keep the file's own.
        psnr_check: forwarded to the re-encoding path.
        per_channel: if True, measure every channel separately (requires
            re-encoding).

    Returns:
        Whatever the selected helper returns: a float from
        ``compute_bpd_from_file``, or the result of
        ``convert_and_compute_bpd_with_pil``.

    Raises:
        ValueError: if ``per_channel`` is requested while the file would be
            measured as stored on disk, or if a colorspace is requested
            without an extension to encode with.
    """
    use_file_ext = (
        extension is None
        or file_path.suffix.lower() == "." + extension.lower()
    )
    if use_file_ext and colorspace is None:
        # Nothing to convert: the size on disk is the measurement.
        if per_channel:
            raise ValueError(
                "per_channel requires re-encoding the image; it cannot be "
                f"combined with measuring {file_path} as stored on disk"
            )
        return compute_bpd_from_file(file_path)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if extension is None:
        raise ValueError(
            "an extension is required when a colorspace conversion is requested"
        )
    return convert_and_compute_bpd_with_pil(
        file_path, extension, colorspace, psnr_check, per_channel
    )
def convert_to_colorspace(img: Image.Image, colorspace: str):
    """Return ``img`` converted to ``colorspace``.

    Every colorspace other than "YCoCg" is delegated to ``img.convert``.
    For "YCoCg" the image is first converted to RGB, transformed with
    ``YCoCg_from_RGB`` and wrapped back into a PIL image.
    """
    if colorspace != "YCoCg":
        return img.convert(colorspace)

    rgb_pixels = np.array(img.convert("RGB"))
    # The transformed channels are stored in an ordinary image, so PNG will
    # think this is RGB.
    return Image.fromarray(YCoCg_from_RGB(rgb_pixels))
def convert_and_compute_bpd_with_pil(
    file_path: Path,
    extension: str,
    colorspace: Optional[str],
    psnr_check: bool,
    per_channel: bool,
):
    """Re-encode an image in memory and report its bits-per-dimension.

    The file on disk is never modified: the image is optionally converted to
    ``colorspace``, saved to an in-memory buffer in the requested format, and
    the buffer size is divided by ``width * height * channels``.

    Args:
        file_path: path of the source image.
        extension: target format; either a PIL format name or a file
            extension known to PIL (e.g. "jpg" is translated to "JPEG").
        colorspace: target colorspace, or None to keep the file's mode.
        psnr_check: if True, decode the buffer again and print the PSNR
            against the original pixels.
        per_channel: if True, encode and measure each channel of the
            (converted) image separately.

    Returns:
        A float bpd, or when ``per_channel`` is set a dict mapping channel
        index to bpd.
    """
    # `Image.save(format=...)` expects a format name, not an extension, so
    # "jpg" would fail. Translate registered extensions; pass anything else
    # through unchanged.
    save_format = Image.registered_extensions().get(
        "." + extension.lower(), extension
    )

    def encode_and_measure(img, img_conv, colorspace_msg, channel=None):
        # Encode `img_conv` (or one of its channels) and print/return the bpd.
        if channel is not None:
            img_conv = img_conv.getchannel(channel)
        with io.BytesIO() as byte_stream:
            img_conv.save(
                byte_stream, format=save_format, lossless=True, optimize=True
            )
            bits = len(byte_stream.getvalue()) * 8
            channels = len(img_conv.getbands())
            dims = math.prod(img.size) * channels
            if per_channel:
                # getchannel() always yields a single-band image.
                assert channels == 1
            bpd = bits / dims
            print(
                f"{bpd: .2f} bpd ({file_path}) -> .{extension} w/ {colorspace_msg}",
                end="",
            )
            if psnr_check:
                assert img.size == img_conv.size
                # NOTE(review): with per_channel the decoded single channel is
                # still compared against the full original image, as before;
                # confirm whether a per-channel reference is intended.
                arr = np.array(img)
                # Decode while the buffer is still open.
                img_conv_dec = Image.open(byte_stream)
                # TODO dsevero: need to make YCoCg a proper PIL plugin.
                if colorspace == "YCoCg" and img.mode == "RGB":
                    arr_conv = RGB_from_YCoCg(np.array(img_conv_dec))
                else:
                    arr_conv = np.array(img_conv_dec.convert(img.mode))
                psnr = compute_psnr(arr, arr_conv)
                print(f" PSNR={psnr:.2f}")
            else:
                print("")
        return bpd

    # Context manager so the underlying file handle is always released.
    with Image.open(file_path) as img:
        # Convert once, up front, instead of once per measured channel.
        if colorspace is not None and colorspace != img.mode:
            img_conv = convert_to_colorspace(img, colorspace)
            colorspace_msg = f"{colorspace} (converted)"
        else:
            img_conv = img
            colorspace_msg = f"{img.mode} (from file)"

        if not per_channel:
            return encode_and_measure(img, img_conv, colorspace_msg)

        bpd = {}
        # Iterate over the bands of the image that is actually encoded; the
        # source image may have a different number of bands after conversion.
        for channel in range(len(img_conv.getbands())):
            print(f"Channel {channel}: ", end="")
            bpd[channel] = encode_and_measure(
                img, img_conv, colorspace_msg, channel
            )
        return bpd
def compute_bpd_from_file(file_path: Path):
    """Return bits-per-dimension of an image file exactly as stored on disk.

    The size of the file in bits is divided by ``width * height * channels``;
    the result is also printed.

    Args:
        file_path: path of the image file.

    Returns:
        The bpd as a float.
    """
    # Only the header is needed (size and bands); the context manager makes
    # sure the file handle opened by PIL is closed again.
    with Image.open(file_path) as img:
        channels = len(img.getbands())
        dims = math.prod(img.size) * channels
    bits = os.path.getsize(file_path) * 8
    bpd = bits / dims
    print(f"{bpd: .2f} bpd ({file_path}) ")
    return bpd
def main():
    """Command-line entry point: measure every matching image and print averages.

    Reads from ``sys.argv``: the glob pattern (required), then optionally the
    extension, the colorspace, and the flags ``psnr-check`` / ``per-channel``
    (flags are only looked up from the fourth argument onwards).

    Exits with a usage message when no pattern is given, and returns early
    when the pattern matches no image files.
    """
    if len(sys.argv) < 2:
        sys.exit(
            "Usage: python compute_bpd.py your_glob_pattern "
            "[extension] [colorspace] [psnr-check] [per-channel]"
        )
    pattern = sys.argv[1]
    extension = sys.argv[2].lower() if len(sys.argv) > 2 else None
    colorspace = sys.argv[3] if len(sys.argv) > 3 else None
    flags = sys.argv[4:]
    psnr_check = "psnr-check" in flags
    per_channel = "per-channel" in flags

    # Expand the glob and keep only regular files with an image extension.
    candidates = (Path(match) for match in glob.glob(pattern, recursive=True))
    files = [path for path in candidates if path.is_file() and is_image_file(path)]
    if not files:
        # Avoid dividing by zero / indexing an empty result list below.
        print(f"Found 0 images matching {pattern!r}")
        return

    worker = partial(
        compute_bpd,
        extension=extension,
        colorspace=colorspace,
        psnr_check=psnr_check,
        per_channel=per_channel,
    )
    # Process the images in parallel, one worker per CPU.
    with Pool(processes=os.cpu_count()) as pool:
        bpds = pool.map(worker, files)

    print("-----------------------------------------------------")
    if per_channel:
        # TODO will break if some images have different extensions
        channels = list(bpds[0].keys())
        print(f"Found {len(bpds)} images")
        avg_bpd = 0
        for c in channels:
            avg_bpd_channel = sum(bpd[c] for bpd in bpds) / len(bpds)
            avg_bpd += avg_bpd_channel
            print(f"Channel {c}: {avg_bpd_channel:.2f} bpd (Average)")
        avg_bpd /= len(channels)
        print(f"BPD all channels: {avg_bpd}")
    else:
        avg_bpd = sum(bpds) / len(bpds)
        print(f"Found {len(bpds)} images: {avg_bpd:.2f} bpd (Average)")
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment