Skip to content

Instantly share code, notes, and snippets.

@se7oluti0n
Created April 25, 2025 06:59
Show Gist options
  • Save se7oluti0n/ce25034e3e6753297861004662c53250 to your computer and use it in GitHub Desktop.
Snippet to measure inference FPS (frames per second) of an ONNX model
import onnxruntime as ort
import numpy as np
import time
import argparse
import os
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the original behavior.

    Returns:
        argparse.Namespace with ``model_path`` (str, default "model.onnx")
        and ``input_size`` (int, default 224).
    """
    parser = argparse.ArgumentParser(description="Benchmark FPS for an ONNX model.")
    parser.add_argument("--model-path", type=str, default="model.onnx",
                        help="Path to the ONNX model file (default: model.onnx)")
    parser.add_argument("--input-size", type=int, default=224,
                        help="Input image size (e.g., 224 for 224x224) (default: 224)")
    return parser.parse_args(argv)
def get_execution_provider():
    """Choose the ONNX Runtime execution provider to benchmark with.

    Prints the providers ONNX Runtime reports, then prefers
    ``CUDAExecutionProvider``. If that is unavailable, it falls back to
    ``CPUExecutionProvider`` and prints a warning.

    Returns:
        A one-element list naming the chosen provider, in the form accepted by
        ``ort.InferenceSession(..., providers=...)``.

    Raises:
        RuntimeError: if neither the CUDA nor the CPU provider is listed.
    """
    available = ort.get_available_providers()
    print(f"Available providers: {available}")

    if "CUDAExecutionProvider" not in available:
        # No GPU provider: try CPU before giving up entirely.
        if "CPUExecutionProvider" not in available:
            raise RuntimeError("No suitable execution provider available")
        print("Warning: CUDAExecutionProvider not available, falling back to CPUExecutionProvider")
        return ["CPUExecutionProvider"]

    return ["CUDAExecutionProvider"]
def main(num_iterations=100, warmup_iterations=10):
    """Benchmark inference throughput (FPS) of an ONNX model on dummy input.

    Reads ``--model-path`` and ``--input-size`` from the command line and
    builds an ONNX Runtime session on the provider chosen by
    ``get_execution_provider()``. It then times repeated ``session.run`` calls
    that feed a random ``(1, 3, S, S)`` float32 array to the model's first
    input, and prints the resulting frames per second.

    Args:
        num_iterations: Number of timed inference calls. The default of 100
            is the value the script previously hard-coded.
        warmup_iterations: Number of untimed calls made before timing starts.
            The first runs of a session are often slower (e.g. lazy
            initialization), and warming up keeps that cost out of the FPS
            figure. Pass 0 to disable.

    Raises:
        FileNotFoundError: if the model file does not exist.
        ValueError: if the input size or either iteration count is invalid.
    """
    args = parse_args()

    # Validate inputs before paying for session creation.
    if not os.path.isfile(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    if args.input_size <= 0:
        raise ValueError("Input size must be a positive integer")
    if num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer")
    if warmup_iterations < 0:
        raise ValueError("warmup_iterations must be non-negative")

    providers = get_execution_provider()
    session = ort.InferenceSession(args.model_path, providers=providers)

    # The session can end up on a different provider than requested
    # (e.g. if the CUDA provider fails to initialize), so report the
    # provider the session actually registered first.
    active_provider = session.get_providers()[0]
    if active_provider != providers[0]:
        print(f"Warning: requested {providers[0]} but session is using {active_provider}")

    # Feed the model's first input, whatever it is named.
    input_name = session.get_inputs()[0].name
    print(f"Model expects input name: {input_name}")

    # NOTE(review): assumes the first input accepts float32 of shape
    # (1, 3, S, S). Models with another rank, channel count or dtype will
    # presumably be rejected by session.run.
    input_data = np.random.randn(1, 3, args.input_size, args.input_size).astype(np.float32)
    feed = {input_name: input_data}

    # Untimed warm-up runs.
    for _ in range(warmup_iterations):
        session.run(None, feed)

    # perf_counter is monotonic and high-resolution. time.time() can jump if
    # the wall clock is adjusted during the run.
    start = time.perf_counter()
    for _ in range(num_iterations):
        session.run(None, feed)
    elapsed = time.perf_counter() - start
    fps = num_iterations / elapsed

    print(f"Measured FPS: {fps:.2f} (Model: {args.model_path}, Input Size: {args.input_size}x{args.input_size}, Provider: {active_provider})")
# Run the benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment