Last active
June 3, 2023 01:29
-
-
Save hcho3/1cefe168614560106408ce6bfb9458d9 to your computer and use it in GitHub Desktop.
Predicting with learning-to-rank using FIL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import treelite
import xgboost as xgb
from cuml.experimental import ForestInference
from sklearn.datasets import make_classification
# Layout of the synthetic ranking dataset: N_GROUPS query groups, each holding
# GROUP_SIZE consecutive rows. Every place below that previously hard-coded
# 10 / 100 is derived from these two values so they cannot drift apart.
N_GROUPS = 10
GROUP_SIZE = 10

# Synthetic binary-labelled data; the 0/1 label doubles as the relevance score.
X, y = make_classification(
    n_samples=N_GROUPS * GROUP_SIZE,
    n_features=5,
    n_informative=4,
    n_redundant=1,
    n_repeated=0,
    n_classes=2,
    random_state=0,
)
X, y = X.astype("float32"), y.astype("int32")

dtrain = xgb.DMatrix(X, label=y)
# Register the group sizes on the DMatrix (rows are grouped contiguously).
dtrain.set_group([GROUP_SIZE] * N_GROUPS)

params = {
    "objective": "rank:pairwise",
    "eval_metric": "map",
    "learning_rate": 0.1,
    "max_depth": 4,
}
bst = xgb.train(params, dtrain, num_boost_round=10, evals=[(dtrain, "train")])

# Convert the trained booster to Treelite, then load it into FIL.
# output_class=False is passed so FIL is asked for raw scores, not class labels.
tl_model = treelite.Model.from_xgboost(bst)
fil_model = ForestInference.load_from_treelite_model(tl_model, output_class=False)

# Ensure that Treelite yields correct result
pred = bst.inplace_predict(X)
tl_pred = treelite.gtil.predict(tl_model, X)
np.testing.assert_almost_equal(pred, tl_pred, decimal=5)

# Ensure that FIL yields correct result
fil_pred = fil_model.predict(X)
np.testing.assert_almost_equal(pred, fil_pred.flatten(), decimal=5)

# Relevant docs should be ranked first, before irrelevant docs
for start in range(0, N_GROUPS * GROUP_SIZE, GROUP_SIZE):
    group = slice(start, start + GROUP_SIZE)
    # Negate the scores so argsort yields descending-score order.
    ranking = np.argsort(-pred[group])
    print(y[group][ranking])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment