Skip to content

Instantly share code, notes, and snippets.

@hcho3
Last active October 10, 2024 20:41
Show Gist options
  • Save hcho3/843f02b54fc490bef19d17ead634c919 to your computer and use it in GitHub Desktop.
Save hcho3/843f02b54fc490bef19d17ead634c919 to your computer and use it in GitHub Desktop.
Comparing performance of FIL vs Hummingbird
import pickle
import time
from argparse import ArgumentParser
import numpy as np
from cuml import ForestInference
from cuml.experimental import ForestInference as ForestInferenceNG
from hummingbird.ml import convert, load
# Number of untimed predict() calls each benchmark makes before measuring.
N_WARMUP = 20
# Number of timed predict() calls whose durations feed the statistics.
N_TRIAL = 100
# Maximum number of rows taken from X.npy to form the prediction batch.
BATCH_SIZE = 100000
def compute_stats(elapsed_time, msg):
    """Print the mean, median and standard deviation of a set of timings.

    Args:
        elapsed_time: Collection of per-trial durations; handed unchanged to
            ``np.mean``, ``np.median`` and ``np.std``.
        msg: Label printed in front of the statistics.
    """
    average = np.mean(elapsed_time)
    middle = np.median(elapsed_time)
    spread = np.std(elapsed_time)
    summary = f"avg = {average}, median = {middle}, std = {spread}"
    print(f"{msg}: {summary}")
def old_fil(X, clf):
    """Time ``predict`` on ``X`` using cuML's ``ForestInference``.

    The model is built from ``clf`` via ``load_from_sklearn``, called
    ``N_WARMUP`` times without timing, then called ``N_TRIAL`` times with each
    call timed individually. The statistics are printed under the label
    ``"FIL"``.

    Args:
        X: Input batch passed to ``predict`` on every call.
        clf: Model accepted by ``ForestInference.load_from_sklearn``
            (presumably a fitted scikit-learn forest -- confirm with caller).
    """
    # NOTE(review): storage_type=True presumably selects the sparse forest
    # layout -- confirm against the cuML ForestInference documentation.
    forest = ForestInference.load_from_sklearn(
        clf,
        output_class=False,
        algo="naive",
        storage_type=True,
        threads_per_tree=32,
    )

    # Untimed passes so one-off start-up costs stay out of the measurements.
    for _ in range(N_WARMUP):
        forest.predict(X)

    durations = np.empty(N_TRIAL)
    for trial in range(N_TRIAL):
        began = time.perf_counter()
        forest.predict(X)
        durations[trial] = time.perf_counter() - began

    compute_stats(durations, "FIL")
def new_fil(X, clf):
    """Time ``predict`` on ``X`` using cuML's experimental ``ForestInference``.

    The model is built from ``clf`` via ``load_from_sklearn`` with a
    depth-first layout and a default chunk size of 16. It is called
    ``N_WARMUP`` times without timing, then ``N_TRIAL`` times with each call
    timed individually. The statistics are printed under the label
    ``"Experimental FIL"``.

    Args:
        X: Input batch passed to ``predict`` on every call.
        clf: Model accepted by ``ForestInferenceNG.load_from_sklearn``
            (presumably a fitted scikit-learn forest -- confirm with caller).
    """
    forest = ForestInferenceNG.load_from_sklearn(
        clf,
        output_class=False,
        layout="depth_first",
        default_chunk_size=16,
    )

    # Untimed passes so one-off start-up costs stay out of the measurements.
    for _ in range(N_WARMUP):
        forest.predict(X)

    durations = np.empty(N_TRIAL)
    for trial in range(N_TRIAL):
        began = time.perf_counter()
        forest.predict(X)
        durations[trial] = time.perf_counter() - began

    compute_stats(durations, "Experimental FIL")
def hummingbird(X, clf):
    """Time ``predict`` on ``X`` using a Hummingbird-converted model.

    ``clf`` is converted with the ``"pytorch"`` backend and moved to the
    ``"cuda"`` device. The converted model is called ``N_WARMUP`` times
    without timing, then ``N_TRIAL`` times with each call timed individually.
    The statistics are printed under the label ``"Hummingbird"``.

    Args:
        X: Input batch passed to ``predict`` on every call.
        clf: Model accepted by ``hummingbird.ml.convert``.
    """
    converted = convert(clf, "pytorch")
    converted.to("cuda")

    # Untimed passes so one-off start-up costs stay out of the measurements.
    for _ in range(N_WARMUP):
        converted.predict(X)

    durations = np.empty(N_TRIAL)
    for trial in range(N_TRIAL):
        began = time.perf_counter()
        converted.predict(X)
        durations[trial] = time.perf_counter() - began

    compute_stats(durations, "Hummingbird")
def main(args):
    """Load the pickled model and input data, then run the chosen benchmark.

    Reads the model from ``model.pkl`` and the inputs from ``X.npy`` in the
    current working directory, keeping at most ``BATCH_SIZE`` rows.

    Args:
        args: Object with a ``backend`` attribute. ``"new_fil"``,
            ``"old_fil"`` and ``"hummingbird"`` select the matching
            benchmark; any other value runs nothing.
    """
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only run this against a model.pkl from a trusted source.
    with open("model.pkl", "rb") as model_file:
        clf = pickle.load(model_file)
    X = np.load("X.npy")[:BATCH_SIZE, :]

    runners = {
        "new_fil": new_fil,
        "old_fil": old_fil,
        "hummingbird": hummingbird,
    }
    runner = runners.get(args.backend)
    if runner is not None:
        runner(X, clf)
def _build_parser():
    """Return the command-line parser exposing the required --backend flag."""
    cli = ArgumentParser()
    cli.add_argument(
        "--backend",
        type=str,
        choices=["old_fil", "new_fil", "hummingbird"],
        required=True,
        help="Inference backend to use",
    )
    return cli


# Script entry point: parse the command line and run the selected benchmark.
if __name__ == "__main__":
    main(_build_parser().parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment