Skip to content

Instantly share code, notes, and snippets.

@hcho3
Last active June 3, 2023 01:29
Show Gist options
  • Save hcho3/1cefe168614560106408ce6bfb9458d9 to your computer and use it in GitHub Desktop.
Save hcho3/1cefe168614560106408ce6bfb9458d9 to your computer and use it in GitHub Desktop.
Predicting with learning-to-rank using FIL
import numpy as np
import treelite
import xgboost as xgb
from sklearn.datasets import make_classification
from cuml.experimental import ForestInference
# Build a small synthetic dataset: 100 rows with 5 features (4 informative,
# 1 redundant, none repeated) and binary labels. The fixed random_state keeps
# the generated data identical from run to run.
raw_X, raw_y = make_classification(
    n_samples=100,
    n_features=5,
    n_informative=4,
    n_redundant=1,
    n_repeated=0,
    n_classes=2,
    random_state=0,
)
# Cast to 32-bit types before the data is handed to the libraries below.
X = raw_X.astype(np.float32)
y = raw_y.astype(np.int32)
# Training configuration: pairwise ranking objective, evaluated with the
# "map" metric.
params = {
    "objective": "rank:pairwise",
    "eval_metric": "map",
    "learning_rate": 0.1,
    "max_depth": 4,
}

# Wrap the data for XGBoost and declare the group layout: 10 groups of
# 10 consecutive rows each (100 rows in total).
dtrain = xgb.DMatrix(X, label=y)
group_sizes = [10] * 10
dtrain.set_group(group_sizes)

# Boost for 10 rounds, reporting the metric on the training set itself.
watchlist = [(dtrain, "train")]
bst = xgb.train(params, dtrain, num_boost_round=10, evals=watchlist)
# Reference scores, computed by XGBoost itself.
pred = bst.inplace_predict(X)

# Ensure that Treelite yields correct result: import the booster into
# Treelite and compare its GTIL predictions with the reference scores
# (agreement to 5 decimal places).
tl_model = treelite.Model.from_xgboost(bst)
tl_pred = treelite.gtil.predict(tl_model, X)
np.testing.assert_almost_equal(pred, tl_pred, decimal=5)

# Ensure that FIL yields correct result: load the Treelite model into FIL
# and compare its flattened predictions with the same reference scores.
# NOTE(review): output_class=False presumably requests raw scores rather
# than class labels -- confirm against the cuML documentation.
fil_model = ForestInference.load_from_treelite_model(tl_model, output_class=False)
fil_pred = fil_model.predict(X)
fil_scores = fil_pred.flatten()
np.testing.assert_almost_equal(pred, fil_scores, decimal=5)
# Relevant docs should be ranked first, before irrelevant docs.
# The rows are treated as 10 consecutive groups of 10 rows each. For every
# group, print the labels reordered by descending predicted score, so that
# the labels of the highest-scored rows come first in each printed array.
group_size = 10
n_groups = 10
for group_idx in range(n_groups):
    # Row window covered by this group.
    rows = slice(group_idx * group_size, (group_idx + 1) * group_size)
    # Negating the scores turns argsort's ascending order into descending.
    ranking = np.argsort(-pred[rows])
    print(y[rows][ranking])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment