Skip to content

Instantly share code, notes, and snippets.

@hcho3
Created May 24, 2020 07:43
Show Gist options
  • Save hcho3/47cd3ed4501cc007215b65d8305cc0d4 to your computer and use it in GitHub Desktop.
Save hcho3/47cd3ed4501cc007215b65d8305cc0d4 to your computer and use it in GitHub Desktop.
Early stopping in XGBoost with cross-validation
import numpy as np
import xgboost as xgb
from sklearn.model_selection import KFold
# load_boston was deprecated in scikit-learn 1.0 and removed in 1.2 (importing it now
# raises ImportError). load_diabetes is a small regression dataset bundled with
# scikit-learn, so it loads without any download.
from sklearn.datasets import load_diabetes

nfold = 5  # number of cross-validation (CV) folds
nround = 10000  # maximum number of boosting rounds (early stopping may end training sooner)
# Training stops once the mean validation metric has gone early_stopping_rounds
# consecutive rounds without improving on its best value so far.
early_stopping_rounds = 100

# Feature matrix X and regression target y as NumPy arrays.
X, y = load_diabetes(return_X_y=True)
kfold_gen = KFold(n_splits=nfold, shuffle=True, random_state=1)
# Three index-aligned lists: position k in each one belongs to CV fold k.
#   dtrain[k] -> DMatrix of fold k's training rows
#   dvalid[k] -> DMatrix of fold k's held-out rows
#   bst[k]    -> Booster for fold k (populated later in the script)
dtrain, dvalid, bst = [], [], []
for fit_rows, holdout_rows in kfold_gen.split(X):
    # Fold k is trained on fit_rows and scored on holdout_rows.
    dtrain.append(xgb.DMatrix(X[fit_rows, :], label=y[fit_rows]))
    dvalid.append(xgb.DMatrix(X[holdout_rows, :], label=y[holdout_rows]))
# Booster hyperparameters shared by every fold.
params = dict(objective='reg:squarederror', max_depth=6, learning_rate=0.1, seed=0)
# One Booster per CV fold, created in fold order; the fold's training and
# validation matrices are handed to the constructor as its cache list.
bst.extend(xgb.Booster(params, [dtrain[k], dvalid[k]]) for k in range(nfold))
best_iteration = 0  # round with the lowest mean validation metric seen so far
best_score = float('inf')  # that lowest mean validation metric
for round_idx in range(nround):
    # Advance every fold's Booster by one round, then score it on its held-out fold.
    valid_metric = []
    for booster, fold_train, fold_valid in zip(bst, dtrain, dvalid):
        booster.update(fold_train, round_idx)
        report = booster.eval_set([(fold_valid, 'valid')], round_idx)
        # Take the number after ':' in the second whitespace-separated token.
        # Presumably the report reads like "[<round>]\tvalid-rmse:<value>"; if
        # several metrics were configured, only the first would be used.
        valid_metric.append(float(report.split()[1].split(':')[1]))
    cv_valid_metric = np.mean(valid_metric)  # average over all folds

    # Lower is better. Once early_stopping_rounds rounds pass without improvement,
    # stop; the Boosters then contain rounds beyond best_iteration.
    if cv_valid_metric < best_score:
        best_score, best_iteration = cv_valid_metric, round_idx
    elif round_idx - best_iteration >= early_stopping_rounds:
        print(f'Stopping. Best iteration: {best_iteration}')
        break
    print(f'Iteration {round_idx:-3d}, mean validation RMSE = {cv_valid_metric:-13.10f}, '
          f'Best iteration: {best_iteration:-3d} (mean validation RMSE {best_score:-13.10f})')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment