Skip to content

Instantly share code, notes, and snippets.

@hcho3
Last active April 22, 2020 21:11
Show Gist options
  • Save hcho3/f774c512b5f88a2386a4bff0c652593e to your computer and use it in GitHub Desktop.
Save hcho3/f774c512b5f88a2386a4bff0c652593e to your computer and use it in GitHub Desktop.
Benchmark script to measure round-trip serialization performance of Treelite using Protobuf
/**
* Benchmark script to measure round-trip serialization performance of Treelite using Protobuf.
* Author: Hyunsu Cho ([email protected])
*/
#include <stdio.h>
#include <stdlib.h>
#include <dlfcn.h>
#include <time.h>
/* Opaque handles returned by, and passed back to, the Treelite C API. */
typedef void* model_builder_handle;
typedef void* tree_builder_handle;
typedef void* model_handle;

/*
 * Treelite entry points, resolved at run time by load_lib() through dlsym()
 * instead of being linked at build time. The trailing comment on each line
 * names the library symbol load_lib() binds it to. As file-scope objects they
 * are NULL until load_lib() runs. `static` gives them internal linkage, so
 * generic names such as free_model or set_root are not exported from this
 * translation unit.
 */
static const char* (*last_error) (void);                                    /* TreeliteGetLastError */
static int (*create_model_builder) (int, int, int, model_builder_handle*);   /* TreeliteCreateModelBuilder; 1st arg: feature count */
static int (*create_tree_builder) (tree_builder_handle*);                   /* TreeliteCreateTreeBuilder */
static int (*create_node) (tree_builder_handle, int);                       /* TreeliteTreeBuilderCreateNode */
/* TreeliteTreeBuilderSetNumericalTestNode(builder, node id, feature id,
   comparison operator, threshold, flag, left child, right child).
   NOTE(review): main() passes 1 for the flag; presumably "default left" --
   confirm against treelite/c_api.h. */
static int (*set_test) (tree_builder_handle, int, unsigned, const char*, float, int, int, int);
static int (*set_leaf) (tree_builder_handle, int, float);                   /* TreeliteTreeBuilderSetLeafNode */
static int (*set_root) (tree_builder_handle, int);                          /* TreeliteTreeBuilderSetRootNode */
static int (*insert_tree) (model_builder_handle, tree_builder_handle, int); /* TreeliteModelBuilderInsertTree */
static int (*commit_model) (model_builder_handle, model_handle*);           /* TreeliteModelBuilderCommitModel */
static int (*serialize_binary) (const char*, model_handle);                 /* TreeliteSerializeModel(path, model) */
static int (*deserialize_binary) (const char*, model_handle*);              /* TreeliteDeserializeModel(path, &model) */
static int (*free_tree_builder) (tree_builder_handle);                      /* TreeliteDeleteTreeBuilder */
static int (*free_model_builder) (model_builder_handle);                    /* TreeliteDeleteModelBuilder */
static int (*free_model) (model_handle);                                    /* TreeliteFreeModel */

/* Helpers defined after main(); internal to this file. */
static void* load_func(void* lib_handle, const char* func_name);
static void load_lib(void);
static void check_call(int retval, int lineno);
static double get_time(void);

/* Run a Treelite call and hand its return code to check_call(), tagged with
   the caller's source line so a failure can be located. */
#define CHECK_CALL(x) check_call(x, __LINE__)
int main(void) {
  const int num_feature = 3;
  const int depth = 24;
  load_lib();

  /*
   * Part 1: build a complete binary tree of `depth` levels below the root and
   * commit it as a single-tree model. Node ids are assigned breadth-first:
   * the root is 0, the children of node k are 2k+1 and 2k+2, and the deepest
   * level (the leaves) occupies ids 2^depth - 1 .. 2^(depth+1) - 2.
   */
  const double build_start = get_time();

  model_builder_handle model_builder;
  CHECK_CALL(create_model_builder(num_feature, 1, 0, &model_builder));
  tree_builder_handle tree_builder;
  CHECK_CALL(create_tree_builder(&tree_builder));

  const int node_count = (1 << (depth + 1)) - 1;
  const int first_leaf = (1 << depth) - 1;

  /* First pass: register every node id, in increasing order. */
  for (int node = 0; node < node_count; ++node) {
    CHECK_CALL(create_node(tree_builder, node));
  }

  /* Second pass: leaves hold a constant value; each internal node compares
     feature (level % 2) against 0.0 with "<" and links to its two children. */
  int level = 0;
  int level_end = 1; /* one past the last node id on the current level */
  for (int node = 0; node < node_count; ++node) {
    if (node == level_end) {
      ++level;
      level_end = 2 * level_end + 1;
    }
    if (node >= first_leaf) {
      const float leaf_value = 0.5;
      CHECK_CALL(set_leaf(tree_builder, node, leaf_value));
      continue;
    }
    const int feature_id = level % 2;
    const float threshold = 0.0;
    CHECK_CALL(set_test(tree_builder, node, feature_id, "<", threshold, 1,
                        2 * node + 1, 2 * node + 2));
  }

  CHECK_CALL(set_root(tree_builder, 0));
  CHECK_CALL(insert_tree(model_builder, tree_builder, -1));
  CHECK_CALL(free_tree_builder(tree_builder));
  model_handle model;
  CHECK_CALL(commit_model(model_builder, &model));
  CHECK_CALL(free_model_builder(model_builder));
  printf("%lf sec to construct model in memory\n", get_time() - build_start);

  /* Part 2: time each half of the round trip through model.bin separately. */
  double phase_start = get_time();
  CHECK_CALL(serialize_binary("model.bin", model));
  printf("%lf sec to serialize model to binary\n", get_time() - phase_start);
  CHECK_CALL(free_model(model));

  model_handle reloaded;
  phase_start = get_time();
  CHECK_CALL(deserialize_binary("model.bin", &reloaded));
  printf("%lf sec to deserialize model from binary\n", get_time() - phase_start);
  CHECK_CALL(free_model(reloaded));
  return 0;
}
/*
 * Look up the symbol `func_name` in a library handle obtained from dlopen().
 * On failure, prints the symbol name plus the dynamic linker's own diagnostic
 * (from dlerror()) to stderr and exits with status 2, so it never returns NULL.
 */
void* load_func(void* lib_handle, const char* func_name) {
  /* Clear any stale error so the message reported below belongs to this lookup. */
  (void)dlerror();
  void* func_handle = dlsym(lib_handle, func_name);
  if (!func_handle) {
    const char* reason = dlerror();
    fprintf(stderr, "Could not load function %s().\n", func_name);
    if (reason) {
      fprintf(stderr, "%s\n", reason);
    }
    exit(2);
  }
  return func_handle;
}
/*
 * Open ./libtreelite.so and bind every Treelite entry point this benchmark
 * uses to its function-pointer global. The path is resolved relative to the
 * current working directory. Exits with status 1 if the library cannot be
 * opened, or (through load_func) with status 2 if any symbol is missing.
 * The handle is never dlclose()d, which keeps the bound pointers valid until
 * the process exits.
 */
void load_lib(void) {
  /* RTLD_NOW resolves the library's undefined symbols inside dlopen(), so a
     broken install fails here with a diagnostic instead of at first use, and
     lazy-binding work is not charged to whichever timed phase happens to call
     a function first. */
  void* lib = dlopen("./libtreelite.so", RTLD_NOW | RTLD_LOCAL);
  if (!lib) {
    /* Report the linker's reason: the file may exist but fail to load
       (missing dependency, wrong architecture, ...). */
    const char* reason = dlerror();
    fprintf(stderr, "Please place libtreelite.so in the same directory and try again.\n");
    if (reason) {
      fprintf(stderr, "%s\n", reason);
    }
    exit(1);
  }
  last_error = load_func(lib, "TreeliteGetLastError");
  create_model_builder = load_func(lib, "TreeliteCreateModelBuilder");
  create_tree_builder = load_func(lib, "TreeliteCreateTreeBuilder");
  create_node = load_func(lib, "TreeliteTreeBuilderCreateNode");
  set_test = load_func(lib, "TreeliteTreeBuilderSetNumericalTestNode");
  set_leaf = load_func(lib, "TreeliteTreeBuilderSetLeafNode");
  set_root = load_func(lib, "TreeliteTreeBuilderSetRootNode");
  insert_tree = load_func(lib, "TreeliteModelBuilderInsertTree");
  commit_model = load_func(lib, "TreeliteModelBuilderCommitModel");
  serialize_binary = load_func(lib, "TreeliteSerializeModel");
  deserialize_binary = load_func(lib, "TreeliteDeserializeModel");
  free_tree_builder = load_func(lib, "TreeliteDeleteTreeBuilder");
  free_model_builder = load_func(lib, "TreeliteDeleteModelBuilder");
  free_model = load_func(lib, "TreeliteFreeModel");
}
/*
 * Terminate the benchmark if a Treelite API call reported failure.
 *   retval: the call's return code; 0 is treated as success.
 *   lineno: the caller's source line, supplied by the CHECK_CALL macro.
 * On failure, prints the line number and Treelite's last error message to
 * stderr and exits with status 1.
 */
void check_call(int retval, int lineno) {
  /* Any non-zero code is a failure, not only -1, so an unexpected error code
     cannot be silently ignored. */
  if (retval != 0) {
    const char* msg = last_error();
    /* Passing NULL to %s is undefined behavior; guard it. */
    fprintf(stderr, "Line %d: %s\n", lineno, msg ? msg : "(no error message)");
    exit(1);
  }
}
/*
 * Return a timestamp in seconds, meant for measuring elapsed intervals. Only
 * differences between two calls are meaningful.
 *
 * Uses CLOCK_MONOTONIC when POSIX clocks are available. Unlike the wall clock
 * (CLOCK_REALTIME), it is not affected by discontinuous jumps in system time
 * (e.g. settimeofday or an NTP step), and such a jump would corrupt a measured
 * interval. When CLOCK_MONOTONIC is not defined, it falls back to the C11
 * timespec_get(TIME_UTC) wall clock.
 * Exits with status 3 if the clock cannot be read.
 */
double get_time(void) {
  struct timespec ts;
#if defined(CLOCK_MONOTONIC)
  const int ok = (clock_gettime(CLOCK_MONOTONIC, &ts) == 0);
#else
  const int ok = (timespec_get(&ts, TIME_UTC) == TIME_UTC);
#endif
  if (!ok) {
    fprintf(stderr, "Could not get current timestamp\n");
    exit(3);
  }
  return (double)ts.tv_sec + (double)ts.tv_nsec * 1e-9;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment