Created
October 24, 2017 21:38
-
-
Save frbl/5acb8570d01ffccad3b780972f701fd9 to your computer and use it in GitHub Desktop.
Example of XGBoost's model-update process (process_type = 'update' with the refresh updater) not working in R.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Seed the RNG first so every simulated data set below (and therefore the
# reported failure) is reproducible, then attach the packages the script uses.
set.seed(12345)
library(data.table)
library(xgboost)
## Create simulation function
# Simulate a binary-classification data set in which only feature A carries
# signal and feature B is pure noise.
#
# Args:
#   nobs:  number of observations to draw.
#   delta: label-noise level, must lie in [0, 0.5]. P(Y = 1) is 1 - delta
#          when A > 0 and delta otherwise.
#
# Returns:
#   A list with
#     X_mat:  nobs x 2 numeric matrix with columns "A" and "B", each drawn
#             from N(0, 1);
#     Y_vals: vector of 0/1 labels drawn with the probabilities above.
give_me_data <- function(nobs, delta = 0.05) {
  # Outside [0, 0.5] the probabilities either stop depending on A
  # (0.5 < delta <= 1) or leave [0, 1] (delta < 0 or delta > 1, which makes
  # rbinom() return NA), so reject such values up front.
  stopifnot(is.numeric(delta), all(delta >= 0 & delta <= 0.5))
  # A is drawn before B, so the random-number stream matches the previous
  # data.table-based construction; the column names are unchanged.
  X_mat <- cbind(A = rnorm(nobs, 0, 1), B = rnorm(nobs, 0, 1))
  # For delta in [0, 0.5] this equals 1 - delta where A > 0 and delta elsewhere.
  probs <- pmax(as.numeric(X_mat[, "A"] > 0) - 2 * delta, 0) + delta
  Y_vals <- rbinom(nobs, 1, probs)
  list(X_mat = X_mat, Y_vals = Y_vals)
}
## Simulation sizes
nobs <- 20        # observations per training batch
niter <- 100      # number of update passes over fresh batches
test_nobs <- 200  # size of the fixed test set
# Boosting rounds of the base model. In update mode xgb.train rejects an
# `nrounds` larger than the number of rounds it reads from `xgb_model` (see the
# recorded error at the bottom), so one shared value is used for every call.
n_rounds <- 200

## Create a training set
train_data <- give_me_data(nobs)
## Create a test set
test_data <- give_me_data(test_nobs)

params <- list(objective = "binary:logistic", nthread = 1)
dtrain <- xgb.DMatrix(data = train_data$X_mat, label = train_data$Y_vals)
previous_model <- xgb.train(
  data = dtrain,
  params = params,
  nrounds = n_rounds,
  xgb_model = NULL
)

## Calculate result
# binary:logistic already returns probabilities; `type` is not a documented
# argument of predict() for xgb.Booster objects, so it is not passed.
result <- predict(previous_model, newdata = test_data$X_mat)
## Still works here, predictions look fine
print(result)

# Update mode: refresh the existing trees on new data instead of growing new
# ones (see the XGBoost parameter docs for process_type / refresh_leaf, whose
# documented values are 0/1).
update_params <- modifyList(
  params,
  list(process_type = "update", updater = "refresh", refresh_leaf = 1)
)

## Update niter times
for (i in seq_len(niter)) {
  train_data <- give_me_data(nobs)
  dtrain <- xgb.DMatrix(data = train_data$X_mat, label = train_data$Y_vals)
  previous_model <- xgb.train(
    data = dtrain,
    params = update_params,
    nrounds = n_rounds,
    xgb_model = previous_model
  )
  result <- predict(previous_model, newdata = test_data$X_mat)
  print(result)
  print(i)
}

## It works for the first iteration, then I get
## Error in xgb.train(data = dtrain, params = params, nrounds = 10, xgb_model = previous_model,: nrounds cannot be larger than 0 (nrounds of xgb_model)
## (That message was captured from a run with nrounds = 10.) It means
## xgb.train read 0 rounds from the already-updated model.
## NOTE(review): this presumably reflects a bug in the xgboost version used;
## re-check against the installed version.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment