Created
April 25, 2025 06:59
-
-
Save se7oluti0n/ce25034e3e6753297861004662c53250 to your computer and use it in GitHub Desktop.
Snippet to measure inference FPS (frames per second) of an ONNX model with ONNX Runtime
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import onnxruntime as ort | |
import numpy as np | |
import time | |
import argparse | |
import os | |
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``model_path`` (str, default
        ``"model.onnx"``) and ``input_size`` (int, default ``224``).
    """
    parser = argparse.ArgumentParser(description="Benchmark FPS for an ONNX model.")
    parser.add_argument("--model-path", type=str, default="model.onnx",
                        help="Path to the ONNX model file (default: model.onnx)")
    parser.add_argument("--input-size", type=int, default=224,
                        help="Input image size (e.g., 224 for 224x224) (default: 224)")
    # Passing argv through keeps the default behaviour (argv=None) unchanged
    # while letting callers and tests supply their own argument list.
    return parser.parse_args(argv)
def get_execution_provider():
    """Choose the ONNX Runtime execution provider for the benchmark.

    Prints the providers reported by ``ort.get_available_providers()`` and
    prefers CUDA over CPU.

    Returns:
        A one-element list holding the selected provider name.

    Raises:
        RuntimeError: If neither the CUDA nor the CPU provider is reported.
    """
    detected = ort.get_available_providers()
    print(f"Available providers: {detected}")

    # Preferred case: CUDA is reported, use it.
    if "CUDAExecutionProvider" in detected:
        return ["CUDAExecutionProvider"]

    # Neither CUDA nor CPU is reported, so there is nothing to run on.
    if "CPUExecutionProvider" not in detected:
        raise RuntimeError("No suitable execution provider available")

    print("Warning: CUDAExecutionProvider not available, falling back to CPUExecutionProvider")
    return ["CPUExecutionProvider"]
def main():
    """Benchmark inference throughput of an ONNX model and print the FPS.

    Parses and validates the command-line options, opens an ONNX Runtime
    session on the provider chosen by ``get_execution_provider``, and times
    repeated ``session.run`` calls on one random float32 input of shape
    ``(1, 3, input_size, input_size)`` fed to the model's first input.

    Raises:
        FileNotFoundError: If ``--model-path`` does not point to a file.
        ValueError: If ``--input-size`` is not a positive integer.
    """
    # Parse command-line arguments
    args = parse_args()

    # Validate model path
    if not os.path.isfile(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")

    # Validate input size
    if args.input_size <= 0:
        raise ValueError("Input size must be a positive integer")

    # Select execution provider
    providers = get_execution_provider()

    # Initialize ONNX session
    session = ort.InferenceSession(args.model_path, providers=providers)

    # Report the provider the session actually registered first; this may
    # differ from the requested one if ONNX Runtime could not load it.
    active_provider = session.get_providers()[0]

    # Get the model's input name dynamically
    input_name = session.get_inputs()[0].name
    print(f"Model expects input name: {input_name}")

    # Create dummy input data (built once and reused for every run)
    input_data = np.random.randn(1, 3, args.input_size, args.input_size).astype(np.float32)
    feed = {input_name: input_data}

    # Warm up outside the timed region so one-time setup work done during
    # the first calls does not skew the measurement.
    warmup_iterations = 10
    for _ in range(warmup_iterations):
        session.run(None, feed)

    # Benchmark FPS. perf_counter is a monotonic, high-resolution clock,
    # unlike time.time(), which can be coarse or adjusted mid-run.
    num_iterations = 100
    start = time.perf_counter()
    for _ in range(num_iterations):
        session.run(None, feed)
    elapsed = time.perf_counter() - start

    fps = num_iterations / elapsed
    print(f"Measured FPS: {fps:.2f} (Model: {args.model_path}, Input Size: {args.input_size}x{args.input_size}, Provider: {active_provider})")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment