Last active
February 18, 2025 22:15
-
-
Save fepegar/8c5f5444dbd6b29a44c6fa14c070c4c9 to your computer and use it in GitHub Desktop.
Script to extract features from chest X-rays using RAD-DINO. Try `uv run run_rad_dino.py --help`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "einops",
#     "loguru",
#     "numpy",
#     "pillow",
#     "procex",
#     "rich",
#     "simpleitk",
#     "torch",
#     "tqdm",
#     "transformers",
#     "typer",
# ]
#
# [tool.uv.sources]
# procex = { git = "https://github.com/fepegar/procex" }
# [tool.ruff.lint.isort]
# force-single-line = true
# ///
from __future__ import annotations | |
import enum | |
from functools import cache | |
from itertools import batched | |
from pathlib import Path | |
from typing import Annotated, Sequence | |
import numpy as np | |
import numpy.typing as npt | |
import SimpleITK as sitk | |
import torch | |
import typer | |
from einops import rearrange | |
from loguru import logger | |
from PIL import Image | |
from procex import functional as F | |
from procex.imgio import read_image | |
from rich import print | |
from rich.progress import ( | |
BarColumn, | |
MofNCompleteColumn, | |
Progress, | |
SpinnerColumn, | |
TextColumn, | |
TimeElapsedColumn, | |
) | |
from torch import nn | |
from tqdm.auto import tqdm | |
from transformers import AutoModel, BitImageProcessor | |
# Typer CLI application; prints help when invoked with no arguments.
app = typer.Typer(
    no_args_is_help=True,
)
@enum.unique
class Model(str, enum.Enum):
    """Supported checkpoints, resolved to Hugging Face repos as ``microsoft/<value>``."""

    RAD_DINO = "rad-dino"
    RAD_DINO_MAIRA_2 = "rad-dino-maira-2"
@app.command()
def main(
    input: Annotated[
        Path,
        typer.Argument(
            help=(
                "Input image(s). If it is a DICOM file, it will be temporarily"
                " converted to PNG. If it is a .txt file, it must contain paths"
                " to images, one per line."
            ),
            show_default=False,
            exists=True,
            file_okay=True,
            dir_okay=False,
            writable=False,
            readable=True,
            resolve_path=True,
        ),
    ],
    features_path: Annotated[
        Path | None,
        typer.Option(
            "--features",
            "-f",
            help="Output features file.",
            show_default=False,
            exists=True,
            file_okay=True,
            dir_okay=False,
            writable=True,
            readable=False,
        ),
    ] = None,
    out_dir: Annotated[
        Path | None,
        typer.Option(
            "--out-dir",
            help="Output directory for features files.",
            show_default=False,
            exists=False,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=False,
        ),
    ] = None,
    in_dir: Annotated[
        Path | None,
        typer.Option(
            "--in-dir",
            help=(
                "If passed, the path of the output relative to the output"
                " directory will be the same as the input path relative to"
                " this directory."
            ),
            show_default=False,
            exists=True,
            file_okay=False,
            dir_okay=True,
            writable=False,
            readable=True,
        ),
    ] = None,
    model_name: Annotated[
        Model,
        typer.Option(
            "--model",
            "-m",
            help="Model to use.",
            show_default=False,
        ),
    ] = Model.RAD_DINO,
    batch_size: Annotated[
        int,
        typer.Option(
            "--batch-size",
            "-b",
            help="Batch size.",
            show_default=False,
        ),
    ] = 1,
    cls: Annotated[
        bool,
        typer.Option(
            help="Whether to include the CLS token.",
            show_default=False,
        ),
    ] = True,
    patch: Annotated[
        bool,
        typer.Option(
            help="Whether to include the patch tokens.",
            show_default=False,
        ),
    ] = True,
) -> None:
    """Extract CLS and/or patch features from images with a RAD-DINO model.

    Features are saved as ``.npz`` files, one per input image.
    """
    input_paths = _get_input_paths(input)
    output_paths = _get_output_paths(input_paths, features_path, out_dir, in_dir)
    # NOTE(review): removed leftover debug code (`import sys; sys.exit()`) that
    # made everything below this point unreachable.
    device = _get_device()
    with BarlessProgress() as progress:
        task = progress.add_task("Loading model...", total=1)
        model, processor = _get_model_and_processor(model_name.value, device)
        progress.update(task, advance=1)
    print(f'Running inference on device: "{device}"')
    input_batches = batched(input_paths, batch_size)
    output_batches = batched(output_paths, batch_size)
    # strict=True turns a silent input/output length mismatch (e.g. one
    # --features file for many inputs) into an explicit error.
    iterable = list(zip(input_batches, output_batches, strict=True))
    _process_batches(
        iterable,
        model,
        processor,
        device,
        cls=cls,
        patch=patch,
    )
def _get_input_paths( | |
input_path: Path, | |
) -> list[Path]: | |
if input_path.suffix == ".txt": | |
with input_path.open() as f: | |
input_paths = [Path(line.strip()) for line in f] | |
else: | |
input_paths = [input_path] | |
return input_paths | |
def _get_output_paths( | |
input_paths: list[Path], | |
features_path: Path | None, | |
out_dir: Path | None, | |
in_dir: Path | None, | |
) -> list[Path]: | |
if features_path is not None and out_dir is not None: | |
message = "You can only provide one of --features or --out-dir" | |
logger.error(message) | |
raise typer.Abort | |
elif features_path is not None: | |
output_paths = [features_path] | |
elif out_dir is not None: | |
if in_dir is None: | |
output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths] | |
else: | |
output_paths = [ | |
out_dir / p.relative_to(in_dir).with_suffix(".npz") for p in input_paths | |
] | |
print(output_paths) | |
return output_paths | |
def _process_batches(
    in_paths_out_paths: list[tuple[tuple[Path, ...], tuple[Path, ...]]],
    model: AutoModel,
    processor: BitImageProcessor,
    device: torch.device,
    *,
    cls: bool,
    patch: bool,
):
    """Run feature extraction over every batch, with a progress bar."""
    with BarProgress(transient=True) as progress:
        task_id = progress.add_task(
            "Processing batches...",
            total=len(in_paths_out_paths),
        )
        for batch_inputs, batch_outputs in in_paths_out_paths:
            _process_batch(
                batch_inputs,
                batch_outputs,
                model,
                processor,
                device,
                save_cls=cls,
                save_patch=patch,
            )
            progress.update(task_id, advance=1)
def _process_batch(
    input_paths: Sequence[Path],
    output_paths: Sequence[Path],
    model: AutoModel,
    processor: BitImageProcessor,
    device: torch.device,
    *,
    save_cls: bool,
    save_patch: bool,
):
    """Embed one batch of images and write one ``.npz`` file per image."""
    if not (save_cls or save_patch):
        message = "You must save at least one of the CLS token or the patch tokens."
        logger.error(message)
        raise typer.Abort
    images = [_load_image(path) for path in input_paths]
    cls_embeddings, patch_embeddings = _infer(images, model, processor, device)
    for out_path, cls_emb, patch_emb in zip(
        output_paths, cls_embeddings, patch_embeddings
    ):
        out_path.parent.mkdir(parents=True, exist_ok=True)
        arrays = {}
        if save_cls:
            arrays["cls_embeddings"] = cls_emb
        if save_patch:
            arrays["patch_embeddings"] = patch_emb
        with out_path.open("wb") as f:
            np.savez(f, **arrays)
def _infer(
    images: list[Image.Image],
    model: nn.Module,
    processor: BitImageProcessor,
    device: torch.device,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    """Run the model on a batch of images.

    Returns the CLS embeddings and the patch embeddings (reshaped to a
    batch × channels × height × width grid), both as NumPy arrays.
    """
    batch = processor(images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**batch)
    cls_tokens = outputs.pooler_output
    patch_tokens = outputs.last_hidden_state[:, 1:]  # drop the leading CLS token
    patch_grid = _reshape_patch_embeddings(
        patch_tokens,
        image_size=processor.crop_size["height"],
        patch_size=model.config.patch_size,
    )
    return cls_tokens.cpu().numpy(), patch_grid.cpu().numpy()
def _reshape_patch_embeddings(
    flat_tokens: torch.Tensor,
    *,
    image_size: int,
    patch_size: int,
) -> torch.Tensor:
    """Reshape a flat (batch, h*w, channels) token sequence into a
    (batch, channels, h, w) grid."""
    grid_size = image_size // patch_size
    num_images, _, num_channels = flat_tokens.shape
    grid = flat_tokens.reshape(num_images, grid_size, -1, num_channels)
    return grid.permute(0, 3, 1, 2)
@cache
def _get_model_and_processor(
    model_name: str,
    device: torch.device,
) -> tuple[AutoModel, BitImageProcessor]:
    """Download (and memoize) the model and its image processor from the Hub."""
    repo_id = f"microsoft/{model_name}"
    processor = BitImageProcessor.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id).to(device).eval()
    return model, processor
def _load_image(image_path: Path) -> Image.Image:
    """Open an image file, converting DICOM inputs via SimpleITK first."""
    if image_path.suffix == ".dcm":
        return _load_dicom(image_path)
    return Image.open(image_path)
def _load_dicom(image_path: Path) -> Image.Image:
    """Read a DICOM file and return it as a contrast-enhanced 8-bit PIL image."""
    enhanced = F.enhance_contrast(read_image(image_path), num_bits=8)
    pixel_array = np.squeeze(sitk.GetArrayFromImage(enhanced))
    return Image.fromarray(pixel_array)
def _get_device() -> torch.device: | |
if torch.cuda.is_available(): | |
device = torch.device("cuda") | |
elif torch.backends.mps.is_available(): | |
device = torch.device("mps") | |
else: | |
device = torch.device("cpu") | |
return device | |
class BarlessProgress(Progress):
    """Rich progress display with a spinner and elapsed time, but no bar."""

    def __init__(self, *args, **kwargs):
        super().__init__(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            TimeElapsedColumn(),
            *args,
            **kwargs,
        )
class BarProgress(Progress):
    """Rich progress display with a bar and an M-of-N completion counter."""

    def __init__(self, *args, **kwargs):
        super().__init__(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            *args,
            **kwargs,
        )
# Entry point when run as a script (e.g. `uv run run_rad_dino.py ...`).
if __name__ == "__main__":
    app()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment