Created
February 6, 2024 11:16
-
-
Save fengyuentau/e3619eb15ac87969a539f5313c28558b to your computer and use it in GitHub Desktop.
ONNXRuntime C++ inference example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cmake_minimum_required(VERSION 3.5)
project(tests)

# Location of a local ONNX Runtime checkout; adjust for your machine.
set(ORT_ROOT_DIR "/path/to/onnxruntime")
set(ORT_INCLUDE_DIR "${ORT_ROOT_DIR}/include")
set(ORT_BUILD_DIR "${ORT_ROOT_DIR}/build/MacOS/RelWithDebInfo")
#set(OCV_ROOT_DIR "/path/to/opencv")
#set(OCV_BUILD_DIR "${OCV_ROOT_DIR}/build/install")

# Find OpenCV (an installed build; the commented lines above/below are the
# manual-path alternative kept for reference).
find_package(OpenCV 4.6.0 REQUIRED)
#include_directories(${OCV_BUILD_DIR}/include)
#find_library(OCV_LIB_DIR

# Set up for ORT: headers from the checkout, libonnxruntime from its build dir.
include_directories(${ORT_INCLUDE_DIR})
find_library(ORT_LIB_DIR onnxruntime HINTS ${ORT_BUILD_DIR})

# Benchmark executable
add_executable(ort ort.cpp)
target_link_libraries(ort ${ORT_LIB_DIR} ${OpenCV_LIBS})
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// C++ standard library
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>
// ORT
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
// OCV
#include <opencv2/opencv.hpp>
using namespace cv; | |
using namespace std; | |
/// @brief Product of all elements of a vector (e.g. element count of a tensor shape).
/// @param v  dimensions; an empty vector yields 1 (the empty product).
/// @return   the product, computed and accumulated in type T.
template <typename T>
T vectorProduct(const std::vector<T>& v)
{
    // Init must be T{1}, not the int literal 1: std::accumulate's accumulator
    // takes the type of the init argument, so a plain 1 would silently narrow
    // int64_t products (shape dims at the call sites) back to int.
    return std::accumulate(v.begin(), v.end(), T{1}, std::multiplies<T>());
}
int main(int argc, char** argv) | |
{ | |
int warmup = 3; | |
int mruns = 10; | |
String model_path = "/Users/fengyuantao/Workspace/fengyuentau/issues-prs/i-22788-vit/vit_b_32.sim.onnx"; | |
// Random input | |
const vector<int> input_shape{1, 3, 224, 224}; | |
Mat blob(4, input_shape.data(), CV_32FC1); | |
randu(blob, 0.f, 1.f); | |
// ORT env | |
Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "ort"); | |
Ort::SessionOptions sessionOptions; | |
int thread_num = 8; | |
sessionOptions.SetIntraOpNumThreads(thread_num); | |
std::cout << "thread_num = " << thread_num << std::endl; | |
// Load net | |
Ort::Session session(env, model_path.c_str(), sessionOptions); | |
Ort::AllocatorWithDefaultOptions allocator; | |
size_t numInputNodes = session.GetInputCount(); | |
size_t numOutputNodes = session.GetOutputCount(); | |
std::vector<const char*> inputNames{"input"}; | |
std::vector<const char*> outputNames{"output"}; | |
// handling inputs | |
Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(0); | |
auto inputTensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo(); | |
std::vector<int64_t> inputDims = inputTensorInfo.GetShape(); | |
size_t inputTensorSize = vectorProduct(inputDims); | |
std::vector<float> inputTensorValues(inputTensorSize); | |
inputTensorValues.assign(blob.begin<float>(), | |
blob.end<float>()); // =setInput | |
// handling outputs | |
Ort::TypeInfo outputTypeInfo = session.GetOutputTypeInfo(0); | |
auto outputTensorInfo = outputTypeInfo.GetTensorTypeAndShapeInfo(); | |
std::vector<int64_t> outputDims = outputTensorInfo.GetShape(); | |
size_t outputTensorSize = vectorProduct(outputDims); | |
std::vector<float> outputTensorValues(outputTensorSize); | |
// handling memory | |
Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu( | |
OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); | |
std::vector<Ort::Value> inputTensors; | |
inputTensors.push_back(Ort::Value::CreateTensor<float>( | |
memoryInfo, inputTensorValues.data(), inputTensorSize, inputDims.data(), | |
inputDims.size())); | |
std::vector<Ort::Value> outputTensors; | |
outputTensors.push_back(Ort::Value::CreateTensor<float>( | |
memoryInfo, outputTensorValues.data(), outputTensorSize, | |
outputDims.data(), outputDims.size())); | |
// Forward | |
// Forward -- warmup 10 times | |
for (int i = 0; i < warmup; i++) | |
{ | |
session.Run(Ort::RunOptions{nullptr}, | |
inputNames.data(), inputTensors.data(), 1, | |
outputNames.data(), outputTensors.data(), 1); | |
} | |
// Forward -- run 30 times for measurement | |
std::vector<double> times; | |
TickMeter tm; | |
for (int i = 0; i < mruns; i++) | |
{ | |
tm.reset(); | |
tm.start(); | |
session.Run(Ort::RunOptions{nullptr}, | |
inputNames.data(), inputTensors.data(), 1, | |
outputNames.data(), outputTensors.data(), 1); | |
tm.stop(); | |
times.push_back(tm.getTimeMilli()); | |
} | |
double sum = 0; | |
double min = times[0]; | |
for (auto t : times) | |
{ | |
sum += t; | |
if (min > t) | |
min = t; | |
} | |
double mean = sum / times.size(); | |
double median = (times[4] + times[5]) / 2; | |
std::cout << cv::format("mean = %f, median = %f, min = %f\n", mean, median, min); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment