@fepegar
Last active February 18, 2025 22:15
Script to extract features from chest X-rays using RAD-DINO. Try `uv run run_rad_dino.py --help`.
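
For example (the file and directory names below are illustrative), features can be saved for a single image, or for every image listed in a text file:

    uv run run_rad_dino.py chest.png --features chest.npz
    uv run run_rad_dino.py paths.txt --out-dir features/ --in-dir /data/images/ --batch-size 4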
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "einops",
#     "loguru",
#     "numpy",
#     "pillow",
#     "procex",
#     "rich",       # imported directly (progress display)
#     "simpleitk",  # imported directly (DICOM handling)
#     "torch",
#     "transformers",
#     "typer",
# ]
#
# [tool.uv.sources]
# procex = { git = "https://github.com/fepegar/procex" }
# [tool.ruff.lint.isort]
# force-single-line = true
# ///
from __future__ import annotations
import enum
from functools import cache
from itertools import batched
from pathlib import Path
from typing import Annotated, Sequence
import numpy as np
import numpy.typing as npt
import SimpleITK as sitk
import torch
import typer
from einops import rearrange
from loguru import logger
from PIL import Image
from procex import functional as F
from procex.imgio import read_image
from rich import print
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
)
from torch import nn
from transformers import AutoModel, BitImageProcessor
app = typer.Typer(
no_args_is_help=True,
)
@enum.unique
class Model(str, enum.Enum):
RAD_DINO = "rad-dino"
RAD_DINO_MAIRA_2 = "rad-dino-maira-2"
@app.command()
def main(
input: Annotated[
Path,
typer.Argument(
help=(
"Input image(s). If it is a DICOM file, it will be temporarily"
" converted to PNG. If it is a .txt file, it must contain paths"
"to images, one per line."
),
show_default=False,
exists=True,
file_okay=True,
dir_okay=False,
writable=False,
readable=True,
resolve_path=True,
),
],
features_path: Annotated[
Path | None,
typer.Option(
"--features",
"-f",
help="Output features file.",
show_default=False,
            exists=False,
file_okay=True,
dir_okay=False,
writable=True,
readable=False,
),
] = None,
out_dir: Annotated[
Path | None,
typer.Option(
"--out-dir",
help="Output directory for features files.",
show_default=False,
exists=False,
file_okay=False,
dir_okay=True,
writable=True,
readable=False,
),
] = None,
in_dir: Annotated[
Path | None,
typer.Option(
"--in-dir",
help=(
"If passed, the path of the output relative to the output"
" directory will be the same as the input path relative to"
" this directory."
),
show_default=False,
exists=True,
file_okay=False,
dir_okay=True,
writable=False,
readable=True,
),
] = None,
model_name: Annotated[
Model,
typer.Option(
"--model",
"-m",
help="Model to use.",
show_default=False,
),
] = Model.RAD_DINO,
batch_size: Annotated[
int,
typer.Option(
"--batch-size",
"-b",
help="Batch size.",
show_default=False,
),
] = 1,
cls: Annotated[
bool,
typer.Option(
help="Whether to include the CLS token.",
show_default=False,
),
] = True,
patch: Annotated[
bool,
typer.Option(
help="Whether to include the patch tokens.",
show_default=False,
),
] = True,
) -> None:
input_paths = _get_input_paths(input)
output_paths = _get_output_paths(input_paths, features_path, out_dir, in_dir)
device = _get_device()
with BarlessProgress() as progress:
task = progress.add_task("Loading model...", total=1)
model, processor = _get_model_and_processor(model_name.value, device)
progress.update(task, advance=1)
print(f'Running inference on device: "{device}"')
input_batches = batched(input_paths, batch_size)
output_batches = batched(output_paths, batch_size)
iterable = list(zip(input_batches, output_batches))
_process_batches(
iterable,
model,
processor,
device,
cls=cls,
patch=patch,
)
def _get_input_paths(
input_path: Path,
) -> list[Path]:
if input_path.suffix == ".txt":
with input_path.open() as f:
input_paths = [Path(line.strip()) for line in f]
else:
input_paths = [input_path]
return input_paths
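# A .txt input file lists one image path per line, e.g. (hypothetical paths):
#   /data/cxr/study_001/frontal.dcm
#   /data/cxr/study_002/frontal.png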
def _get_output_paths(
input_paths: list[Path],
features_path: Path | None,
out_dir: Path | None,
in_dir: Path | None,
) -> list[Path]:
    if features_path is not None and out_dir is not None:
        message = "You can only provide one of --features or --out-dir"
        logger.error(message)
        raise typer.Abort()
    elif features_path is not None:
        if len(input_paths) > 1:
            message = "Use --out-dir to process more than one image"
            logger.error(message)
            raise typer.Abort()
        output_paths = [features_path]
    elif out_dir is not None:
        if in_dir is None:
            output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths]
        else:
            output_paths = [
                out_dir / p.relative_to(in_dir).with_suffix(".npz") for p in input_paths
            ]
    else:
        message = "You must provide one of --features or --out-dir"
        logger.error(message)
        raise typer.Abort()
    return output_paths
def _process_batches(
in_paths_out_paths: list[tuple[tuple[Path, ...], tuple[Path, ...]]],
model: AutoModel,
processor: BitImageProcessor,
device: torch.device,
*,
cls: bool,
patch: bool,
) -> None:
message = "Processing batches..."
with BarProgress(transient=True) as progress:
task = progress.add_task(message, total=len(in_paths_out_paths))
for inputs_batch, outputs_batch in in_paths_out_paths:
_process_batch(
inputs_batch,
outputs_batch,
model,
processor,
device,
save_cls=cls,
save_patch=patch,
)
progress.update(task, advance=1)
def _process_batch(
input_paths: Sequence[Path],
output_paths: Sequence[Path],
model: AutoModel,
processor: BitImageProcessor,
device: torch.device,
*,
save_cls: bool,
save_patch: bool,
) -> None:
    if not save_cls and not save_patch:
        message = "You must save at least one of the CLS token or the patch tokens."
        logger.error(message)
        raise typer.Abort()
images = [_load_image(p) for p in input_paths]
cls_embeddings, patch_embeddings = _infer(images, model, processor, device)
zipped = zip(output_paths, cls_embeddings, patch_embeddings)
for features_path, cls_embedding, patch_embedding in zipped:
features_path.parent.mkdir(parents=True, exist_ok=True)
kwargs = {}
if save_cls:
kwargs["cls_embeddings"] = cls_embedding
if save_patch:
kwargs["patch_embeddings"] = patch_embedding
with features_path.open("wb") as f:
np.savez(f, **kwargs)
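# The saved archives can be read back with NumPy. A minimal sketch (shapes
# assume RAD-DINO defaults: 768-dim embeddings, a 37x37 patch grid):
#   features = np.load("image.npz")
#   cls_embedding = features["cls_embeddings"]      # shape (768,)
#   patch_embedding = features["patch_embeddings"]  # shape (768, 37, 37)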
def _infer(
images: list[Image.Image],
model: nn.Module,
processor: BitImageProcessor,
device: torch.device,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
processed = processor(images, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**processed)
cls_embeddings = outputs.pooler_output
flat_patch_embeddings = outputs.last_hidden_state[:, 1:] # first token is CLS
reshaped_patch_embeddings = _reshape_patch_embeddings(
flat_patch_embeddings,
image_size=processor.crop_size["height"],
patch_size=model.config.patch_size,
)
return cls_embeddings.cpu().numpy(), reshaped_patch_embeddings.cpu().numpy()
def _reshape_patch_embeddings(
flat_tokens: torch.Tensor,
*,
image_size: int,
patch_size: int,
) -> torch.Tensor:
"""Reshape flat list of patch tokens into a nice grid."""
embeddings_size = image_size // patch_size
patches_grid = rearrange(flat_tokens, "b (h w) c -> b c h w", h=embeddings_size)
return patches_grid
@cache
def _get_model_and_processor(
model_name: str,
device: torch.device,
) -> tuple[AutoModel, BitImageProcessor]:
repo = f"microsoft/{model_name}"
model = AutoModel.from_pretrained(repo).to(device).eval()
processor = BitImageProcessor.from_pretrained(repo)
return model, processor
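# Model names map to checkpoints on the Hugging Face Hub,
# e.g. Model.RAD_DINO resolves to microsoft/rad-dino.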
def _load_image(image_path: Path) -> Image.Image:
if image_path.suffix == ".dcm":
image = _load_dicom(image_path)
else:
image = Image.open(image_path)
return image
def _load_dicom(image_path: Path) -> Image.Image:
image = read_image(image_path)
image = F.enhance_contrast(image, num_bits=8)
array = sitk.GetArrayFromImage(image)
array = np.squeeze(array)
return Image.fromarray(array)
def _get_device() -> torch.device:
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
return device
class BarlessProgress(Progress):
def __init__(self, *args, **kwargs):
columns = [
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
TimeElapsedColumn(),
]
super().__init__(*columns, *args, **kwargs)
class BarProgress(Progress):
def __init__(self, *args, **kwargs):
columns = [
TextColumn("[progress.description]{task.description}"),
BarColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
]
super().__init__(*columns, *args, **kwargs)
if __name__ == "__main__":
app()