@a-r-r-o-w
Created August 24, 2024 19:48
Demonstrates how to run 49-frame inference with CogVideoX within 8 GB of VRAM
# Requires torchao installed from source and PyTorch nightly.
# Other environments have not yet been tested.
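# Example install commands (an assumption, not part of the gist; adjust the
# CUDA version in the index URL to your setup):
#   pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
#   pip install git+https://github.com/pytorch/ao.git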
import gc
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.quantization import (
    quantize_,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
)

def reset_memory(device):
    # Clear Python garbage and the CUDA cache, then reset the peak/accumulated
    # memory counters so the run below is measured from a clean slate.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)

def print_memory(device):
    # Report currently allocated, peak allocated, and peak reserved CUDA memory in GiB.
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.3f}")
    print(f"{max_memory=:.3f}")
    print(f"{max_reserved=:.3f}")

# Either "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
model_id = "THUDM/CogVideoX-5b"
quantization = int8_weight_only
# Load each component in bf16, then quantize it in-place with torchao.
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
quantize_(text_encoder, quantization())

transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
quantize_(transformer, quantization())

vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
quantize_(vae, quantization())
# Create pipeline and run inference
pipe = CogVideoXPipeline.from_pretrained(
    model_id,
    text_encoder=text_encoder,
    transformer=transformer,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
# CPU offload is skipped for int4 and the whole pipeline is kept on the GPU
# (an assumption: torchao's int4 weight-only kernels appear to be CUDA-only).
if quantization == int4_weight_only:
    pipe.to("cuda")
else:
    pipe.enable_model_cpu_offload()
# Decode the video in tiles to keep the VAE's peak memory low.
pipe.vae.enable_tiling()
reset_memory("cuda")
prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(
    prompt=prompt,
    num_frames=49,  # the pipeline default, stated explicitly to match the title
    guidance_scale=6,
    use_dynamic_cfg=True,
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(3047),  # https://arxiv.org/abs/2109.08203
).frames[0]
print_memory("cuda")
export_to_video(video, "output.mp4", fps=8)
a-r-r-o-w commented Aug 24, 2024

The numbers below come from running the code on a single 80 GB A100. The reported value is "max_memory"; "max_reserved" is a bit higher.

int8_weight_only: 7.976 GB
int8_dynamic_activation_int8_weight: 8.708 GB
int4_weight_only: 9.418 GB
int8_dynamic_activation_int4_weight: 11.272 GB
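
To reproduce all four rows in one go, here is a minimal sketch of a sweep, assuming the same environment as the gist (each scheme reloads and re-quantizes the models so the peak statistics don't bleed between runs; the short prompt is a placeholder, so absolute numbers may differ slightly):

import gc
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline
from transformers import T5EncoderModel
from torchao.quantization import (
    quantize_,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
)

model_id = "THUDM/CogVideoX-5b"


def peak_memory_gib(quantization, prompt):
    # Fresh components per scheme so earlier runs don't skew the peak stats.
    text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
    quantize_(text_encoder, quantization())
    transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
    quantize_(transformer, quantization())
    vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
    quantize_(vae, quantization())

    pipe = CogVideoXPipeline.from_pretrained(
        model_id, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
    )
    if quantization == int4_weight_only:
        pipe.to("cuda")  # int4 path keeps everything on the GPU, as above
    else:
        pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats("cuda")
    pipe(prompt=prompt, guidance_scale=6, use_dynamic_cfg=True, num_inference_steps=50)
    return torch.cuda.max_memory_allocated("cuda") / 1024**3


for scheme in (
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
):
    print(f"{scheme.__name__}: {peak_memory_gib(scheme, 'a panda playing guitar in a bamboo forest'):.3f} GiB")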
