Last active
October 10, 2024 20:41
-
-
Save hcho3/843f02b54fc490bef19d17ead634c919 to your computer and use it in GitHub Desktop.
Comparing performance of FIL vs Hummingbird
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle
import time
from argparse import ArgumentParser

import numpy as np
from cuml import ForestInference
from cuml.experimental import ForestInference as ForestInferenceNG
from hummingbird.ml import convert, load
# Benchmark configuration shared by every backend.
N_WARMUP = 20  # untimed predict() calls issued before measurement starts
N_TRIAL = 100  # timed predict() calls per backend
BATCH_SIZE = 100000  # number of leading rows of X.npy used as the input batch
def compute_stats(elapsed_time, msg):
    """Print the mean, median and standard deviation of timing samples.

    Despite its name this function returns nothing; it only writes one
    summary line to standard output.

    Args:
        elapsed_time: Sequence or array of per-trial durations. The callers in
            this file pass differences of ``time.perf_counter()`` readings.
        msg: Label printed in front of the statistics.
    """
    print(
        f"{msg}: avg = {np.mean(elapsed_time)}, median = {np.median(elapsed_time)}, "
        f"std = {np.std(elapsed_time)}"
    )
def old_fil(X, clf):
    """Benchmark ``predict`` of the ``cuml.ForestInference`` backend.

    ``clf`` is converted with ``ForestInference.load_from_sklearn``; the
    resulting model then runs ``predict(X)`` ``N_WARMUP`` times untimed and
    ``N_TRIAL`` times timed. The per-call wall-clock durations are summarised
    by ``compute_stats`` under the label ``"FIL"``.

    Args:
        X: Input batch handed unchanged to ``predict``.
        clf: Model object accepted by ``load_from_sklearn`` (presumably a
            fitted scikit-learn forest; verify against the caller's pickle).
    """
    fm = ForestInference.load_from_sklearn(
        clf,
        output_class=False,
        algo="naive",
        # NOTE(review): a boolean storage_type is believed to select sparse
        # (True) vs. dense (False) storage; confirm for the installed cuML.
        storage_type=True,
        threads_per_tree=32,
    )
    # Warm-up calls are excluded from the measurements.
    for _ in range(N_WARMUP):
        fm.predict(X)
    elapsed_time = []
    for _ in range(N_TRIAL):
        tstart = time.perf_counter()
        fm.predict(X)
        tend = time.perf_counter()
        elapsed_time.append(tend - tstart)
    elapsed_time = np.array(elapsed_time)
    compute_stats(elapsed_time, "FIL")
def new_fil(X, clf):
    """Benchmark ``predict`` of the ``cuml.experimental.ForestInference`` backend.

    ``clf`` is converted with ``ForestInferenceNG.load_from_sklearn`` using a
    depth-first layout and a default chunk size of 16; the resulting model
    then runs ``predict(X)`` ``N_WARMUP`` times untimed and ``N_TRIAL`` times
    timed. The per-call wall-clock durations are summarised by
    ``compute_stats`` under the label ``"Experimental FIL"``.

    Args:
        X: Input batch handed unchanged to ``predict``.
        clf: Model object accepted by ``load_from_sklearn`` (presumably a
            fitted scikit-learn forest; verify against the caller's pickle).
    """
    fm_ng = ForestInferenceNG.load_from_sklearn(
        clf,
        output_class=False,
        layout="depth_first",
        default_chunk_size=16,
    )
    # Warm-up calls are excluded from the measurements.
    for _ in range(N_WARMUP):
        fm_ng.predict(X)
    elapsed_time = []
    for _ in range(N_TRIAL):
        tstart = time.perf_counter()
        fm_ng.predict(X)
        tend = time.perf_counter()
        elapsed_time.append(tend - tstart)
    elapsed_time = np.array(elapsed_time)
    compute_stats(elapsed_time, "Experimental FIL")
def hummingbird(X, clf):
    """Benchmark ``predict`` of a Hummingbird model converted to PyTorch.

    ``clf`` is converted with ``convert(clf, "pytorch")`` and moved to the
    ``"cuda"`` device; the resulting model then runs ``predict(X)``
    ``N_WARMUP`` times untimed and ``N_TRIAL`` times timed. The per-call
    wall-clock durations are summarised by ``compute_stats`` under the label
    ``"Hummingbird"``.

    Args:
        X: Input batch handed unchanged to ``predict``.
        clf: Model object accepted by Hummingbird's ``convert``.
    """
    model = convert(clf, "pytorch")
    model.to("cuda")
    # Warm-up calls are excluded from the measurements.
    for _ in range(N_WARMUP):
        model.predict(X)
    elapsed_time = []
    for _ in range(N_TRIAL):
        # NOTE(review): the timing assumes predict() returns only once the
        # result is available to the caller (no pending GPU work); confirm.
        tstart = time.perf_counter()
        model.predict(X)
        tend = time.perf_counter()
        elapsed_time.append(tend - tstart)
    elapsed_time = np.array(elapsed_time)
    compute_stats(elapsed_time, "Hummingbird")
def main(args):
    """Load the model and data from disk and run the selected benchmark.

    Reads ``model.pkl`` and ``X.npy`` from the current working directory,
    keeps the first ``BATCH_SIZE`` rows of the array, and dispatches to
    ``new_fil``, ``old_fil`` or ``hummingbird`` according to ``args.backend``.

    Args:
        args: Object with a ``backend`` attribute equal to ``"new_fil"``,
            ``"old_fil"`` or ``"hummingbird"``.

    Raises:
        ValueError: If ``args.backend`` is not one of the supported names.
            The check runs before any file is opened, so an unsupported
            backend no longer loads the data and then silently does nothing.
    """
    backend = args.backend
    if backend not in ("new_fil", "old_fil", "hummingbird"):
        raise ValueError(
            f"Unknown backend {backend!r}; expected one of "
            "'old_fil', 'new_fil', 'hummingbird'"
        )

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only run this script against a model.pkl from a trusted source.
    with open("model.pkl", "rb") as f:
        clf = pickle.load(f)
    X = np.load("X.npy")[:BATCH_SIZE, :]

    if backend == "new_fil":
        new_fil(X, clf)
    elif backend == "old_fil":
        old_fil(X, clf)
    else:
        hummingbird(X, clf)
# Script entry point: parse the command line and run the chosen benchmark.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--backend",
        type=str,
        choices=["old_fil", "new_fil", "hummingbird"],
        required=True,
        help="Inference backend to use",
    )
    parsed_args = parser.parse_args()
    main(parsed_args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment