@a-r-r-o-w
Created February 21, 2025 03:46
Tests multiple offloading mechanisms and gathers their CPU and CUDA memory/time usage on a single A100 GPU for Flux.
import argparse
import functools
import json
import os
import pathlib
import psutil
import time

import torch
from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading
from memory_profiler import profile


def get_memory_usage():
    # CPU resident-set size (RSS) of this process, in bytes
    process = psutil.Process(os.getpid())
    mem_bytes = process.memory_info().rss
    return mem_bytes
@profile(precision=2)
def apply_offload(pipe: FluxPipeline, method: str) -> None:
    if method == "full_cuda":
        # Baseline: the whole pipeline stays resident on the GPU
        pipe.to("cuda")
    elif method == "model_offload":
        pipe.enable_model_cpu_offload()
    elif method == "sequential_offload":
        pipe.enable_sequential_cpu_offload()
    elif method == "group_offload_block_1":
        # Group offloading is applied to the text encoders and transformer only;
        # the VAE is kept on the GPU in all group offloading variants.
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")
    elif method == "group_offload_leaf":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")
    elif method == "group_offload_block_1_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")
    elif method == "group_offload_leaf_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]))
        pipe.vae.to("cuda")
@profile(precision=2)
def load_pipeline():
    cache_dir = "/raid/.cache/huggingface"
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Dev", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
    return pipe
@torch.no_grad()
def main(args):
    pipe = load_pipeline()
    apply_offload(pipe, args.method)
    apply_offload_memory_usage = get_memory_usage()

    # CUDA memory reserved after the offloading strategy is applied, before inference
    torch.cuda.reset_peak_memory_stats()
    cuda_model_memory = torch.cuda.max_memory_reserved()

    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    run_inference_memory_usage_list = []

    def cpu_mem_callback():
        nonlocal run_inference_memory_usage_list
        run_inference_memory_usage_list.append(get_memory_usage())

    @profile(precision=2)
    def run_inference():
        prompt = "A cat holding a sign that says hello world"
        image = pipe(
            prompt,
            height=1024,
            width=1024,
            num_inference_steps=30,
            guidance_scale=6.0,
            generator=torch.Generator().manual_seed(42),
            # Record CPU RSS after every denoising step
            callback_on_step_end=lambda *args, **kwargs: [cpu_mem_callback(), kwargs][1],
        ).images[0]
        image.save(output_dir / f"output_{args.method}.png")

    t1 = time.time()
    run_inference()
    torch.cuda.synchronize()
    t2 = time.time()

    cuda_inference_memory = torch.cuda.max_memory_reserved()
    time_required = t2 - t1
    run_inference_memory_usage = sum(run_inference_memory_usage_list) / len(run_inference_memory_usage_list)
    print(f"Run inference memory usage list: {run_inference_memory_usage_list}")

    # Memory figures are reported in GiB
    info = {
        "time": round(time_required, 2),
        "cuda_model_memory": round(cuda_model_memory / 1024**3, 2),
        "cuda_inference_memory": round(cuda_inference_memory / 1024**3, 2),
        "cpu_offload_memory": round(apply_offload_memory_usage / 1024**3, 2),
        "cpu_inference_memory": round(run_inference_memory_usage / 1024**3, 2),
    }
    with open(output_dir / f"memory_usage_{args.method}.json", "w") as f:
        json.dump(info, f, indent=4)
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="full_cuda", choices=["full_cuda", "model_offload", "sequential_offload", "group_offload_block_1", "group_offload_leaf", "group_offload_block_1_stream", "group_offload_leaf_stream"])
    parser.add_argument("--output_dir", type=str, default="offload_profiling")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
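The benchmark is meant to be run once per offloading method; each run writes output_{method}.png and memory_usage_{method}.json into --output_dir. A minimal driver sketch for sweeping all methods (not part of the gist; benchmark_flux_offload.py is a placeholder name for however the script above is saved):

import subprocess

METHODS = [
    "full_cuda",
    "model_offload",
    "sequential_offload",
    "group_offload_block_1",
    "group_offload_leaf",
    "group_offload_block_1_stream",
    "group_offload_leaf_stream",
]

for method in METHODS:
    # Each invocation loads the pipeline fresh, so the methods do not interfere with one another
    subprocess.run(
        ["python", "benchmark_flux_offload.py", "--method", method, "--output_dir", "offload_profiling"],
        check=True,
    )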
@a-r-r-o-w (Author):
For group offloading, only the text encoders and the transformer are offloaded; the VAE is always kept on the GPU.
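If the VAE should participate in offloading as well, the same apply_group_offloading call used for the text encoders and transformer could be applied to it in place of pipe.vae.to("cuda"). A minimal sketch of that variation (not part of the gist), mirroring the leaf-level streaming branch above:

import torch
from diffusers.hooks import apply_group_offloading

# Hypothetical variation: group-offload the VAE too, instead of pinning it on the GPU
apply_group_offloading(
    pipe.vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)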
