Skip to content

Instantly share code, notes, and snippets.

@Athospd
Created April 5, 2025 17:22
Show Gist options
  • Save Athospd/9f9245cfc65847ab652eb4f7b0519b4c to your computer and use it in GitHub Desktop.
Save Athospd/9f9245cfc65847ab652eb4f7b0519b4c to your computer and use it in GitHub Desktop.
Comparison of a single model against two individual models for a given set of scenarios
# Load necessary libraries
library(tidymodels)
library(bonsai)
library(tidyverse)
library(ggplot2)
library(patchwork)
# Short alias for glue::glue (string interpolation).
f <- glue::glue
# Make theme_minimal() the default ggplot2 theme for plots drawn after this point.
theme_set(theme_minimal())
# Logistic transform 1 / (1 + exp(-x)), applied element-wise to `x`.
# NOTE: despite its name this is the *inverse* of the logit link: it takes a
# value on the linear-predictor scale and returns a value on the probability
# scale.
logit <- function(x) {
  1 / (1 + exp(-x))
}
# Create dummy data
# Simulate a binary target whose dependence on x1 differs by scenario:
#   scenario "A" (sampled with prob 0.2): x1 contributes 1 + -2 * x1
#   scenario "B" (sampled with prob 0.8): x1 contributes 0 +  1 * x1
# x2 contributes 3 * x2 regardless of scenario.
set.seed(42)
n <- 20000
data <- tibble(
  scenario = sample(c("A", "B"), n, replace = TRUE, prob = c(0.2, 0.8)),
  x1 = runif(n),
  x2 = runif(n)
) %>%
  mutate(
    # Exactly one of the two logical masks is TRUE per row, so this selects the
    # scenario-specific intercept and slope for x1.
    x1_and_scenario_interaction =
      (1 + -2 * x1) * (scenario == "A") + (0 + 1 * x1) * (scenario == "B"),
    lin_pred = 0 + x1_and_scenario_interaction + 3 * x2,
    # `logit()` (defined earlier in this script) maps the linear predictor to
    # the success probability of the Bernoulli draw. Levels are pinned so that
    # "0" is always the first level and "1" the second, whatever was drawn.
    target = factor(rbinom(n, 1, logit(lin_pred)), levels = c(0, 1))
  )

# Class balance of the simulated target.
data %>% count(target)
# Split the data into training and testing sets
# `prop` is the share of rows assigned to the training set, so the models are
# trained on 30% of the rows and the remaining 70% are held out for testing.
data_split <- data %>% initial_split(prop = 0.3)
train_data <- data_split %>% training()
test_data <- data_split %>% testing()
# Model specification: boosted trees for classification on the "lightgbm"
# engine, with 900 trees of depth 8 and a learning rate of 0.01.
model_lgbm <- boost_tree(
  engine = "lightgbm",
  mode = "classification",
  learn_rate = 0.01,
  tree_depth = 8,
  trees = 900
)

# Preprocessing: start from every non-target column as a predictor, drop
# zero-variance predictors (e.g. `scenario` when the training rows come from a
# single scenario), then remove the two simulation-only columns, which would
# otherwise leak the true linear predictor into the model.
recipe_lgbm <- recipe(target ~ ., data = train_data) %>%
  step_zv(all_predictors()) %>%
  step_rm(x1_and_scenario_interaction, lin_pred)

# Bundle preprocessing and model so the same pipeline can be fitted on
# different subsets of the training data.
workflow_lgbm <- workflow() %>%
  add_recipe(recipe_lgbm) %>%
  add_model(model_lgbm)
# Fit the models
# Same workflow, three training sets: scenario A rows only, scenario B rows
# only, and all rows (where `scenario` is available to the model as a feature).
fitted_model_for_A <- fit(
  workflow_lgbm,
  data = filter(train_data, scenario == "A")
)
fitted_model_for_B <- fit(
  workflow_lgbm,
  data = filter(train_data, scenario == "B")
)
fitted_model_for_A_and_B <- fit(workflow_lgbm, data = train_data)
# Make predictions on the test set and calculate ROC AUC
# Bake the test rows through the recipe of the all-rows model; the result is
# column-bound to every model's predictions below so that each prediction row
# carries the columns (`scenario`, `target`) used for the comparison.
test_data_prepped <- fitted_model_for_A_and_B %>%
  extract_recipe() %>%
  bake(new_data = test_data)

# Class-probability predictions of one fitted workflow on `new_data`, tagged
# with a human-readable `model` label and column-bound to `baked_data`.
#   fitted_model: a fitted workflow.
#   model_label:  string stored in the `model` column of every row.
#   new_data:     raw rows to predict on.
#   baked_data:   data frame with one row per row of `new_data`.
predict_with_label <- function(fitted_model, model_label, new_data, baked_data) {
  predict(fitted_model, new_data = new_data, type = "prob") %>%
    mutate(model = model_label) %>%
    bind_cols(baked_data)
}

# Every model is scored on the *full* test set (both scenarios), so each
# scenario-specific model is also evaluated on the scenario it never saw.
predictions <- bind_rows(
  predict_with_label(
    fitted_model_for_A, "model tailored for A",
    test_data, test_data_prepped
  ),
  predict_with_label(
    fitted_model_for_B, "model tailored for B",
    test_data, test_data_prepped
  ),
  predict_with_label(
    fitted_model_for_A_and_B, "model with scenario as feature",
    test_data, test_data_prepped
  )
)
# Results/comparison
# ROC AUC of each model within each scenario, computed on the held-out rows.
# `.pred_0` is the estimate passed in because yardstick treats the first factor
# level of `target` ("0") as the event by default.
predictions %>%
  group_by(scenario, model) %>%
  summarise(
    ROC_AUC_VALIDATION_SET = roc_auc_vec(target, .pred_0),
    # Return an ungrouped tibble instead of one still grouped by `scenario`.
    .groups = "drop"
  )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment