Created
April 5, 2025 17:22
-
-
Save Athospd/9f9245cfc65847ab652eb4f7b0519b4c to your computer and use it in GitHub Desktop.
Comparison of a single model (with the scenario as a feature) against two scenario-specific models, evaluated on each scenario
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Setup ----
# Attach the modelling and data-wrangling stacks; bonsai provides the
# "lightgbm" engine used with boost_tree() further down.
# NOTE(review): ggplot2 is already attached by tidyverse, and ggplot2 /
# patchwork / `f` do not appear to be used in this script.
library(tidymodels)
library(bonsai)
library(tidyverse)
library(ggplot2)
library(patchwork)

# Default ggplot2 theme for any plots drawn in this session.
theme_set(theme_minimal())

# Short alias for string interpolation, e.g. f("n = {n}").
f <- glue::glue
# Logistic ("inverse logit" / expit) function: maps a real-valued linear
# predictor to a probability in [0, 1]; -Inf -> 0, Inf -> 1, NA -> NA.
# Despite the historical name `logit`, this is the *inverse* of the logit,
# so `inv_logit` is the accurate name. `logit` is kept as an alias so that
# existing calls (e.g. in the data simulation) keep working unchanged.
# stats::plogis() evaluates exactly 1 / (1 + exp(-x)) for location 0, scale 1.
inv_logit <- function(x) stats::plogis(x)
logit <- inv_logit
# Simulated data ----
# Each row belongs to scenario "A" (~20%) or "B" (~80%). The x1 effect
# differs by scenario (1 - 2 * x1 in "A", x1 in "B") while x2 contributes
# 3 * x2 in both. `target` is a Bernoulli draw with probability
# logistic(lin_pred), stored as a factor with levels "0"/"1".
set.seed(42)
n <- 20000

data <- tibble(
  scenario = sample(c("A", "B"), n, replace = TRUE, prob = c(0.2, 0.8)),
  x1 = runif(n),
  x2 = runif(n)
) %>%
  mutate(
    x1_and_scenario_interaction = case_when(
      scenario == "A" ~ 1 - 2 * x1,
      scenario == "B" ~ x1,
      TRUE ~ 0
    ),
    lin_pred = x1_and_scenario_interaction + 3 * x2,
    target = factor(rbinom(n, 1, logit(lin_pred)))
  )

# Class balance of the simulated outcome.
data %>% count(target)
# Train/test split ----
# `prop` is the share of rows kept for *training*: 30% are used to fit the
# models and the remaining 70% are held out for evaluation.
# NOTE(review): confirm the 30/70 split is intentional.
data_split <- initial_split(data, prop = 0.3)
train_data <- data_split %>% training()
test_data <- data_split %>% testing()
# Model, preprocessing and workflow ----
# Gradient-boosted trees trained with the LightGBM engine.
model_lgbm <- boost_tree(
  mode = "classification",
  engine = "lightgbm",
  learn_rate = 0.01,
  tree_depth = 8,
  trees = 900
)

# Every other column is a candidate predictor of `target`.
# - step_zv() drops constant predictors, e.g. `scenario` when the training
#   rows come from a single scenario.
# - The two simulation helper columns are removed because they encode the
#   true signal and would leak it into the model.
recipe_lgbm <- recipe(target ~ ., data = train_data) %>%
  step_zv(all_predictors()) %>%
  step_rm(x1_and_scenario_interaction) %>%
  step_rm(lin_pred)

workflow_lgbm <- workflow() %>%
  add_recipe(recipe_lgbm) %>%
  add_model(model_lgbm)
# Fit ----
# Fit one model per scenario, plus one pooled model that sees `scenario` as
# a feature. The original fit order is kept in case the engine draws seeds
# from R's RNG.
fit_workflow_on <- function(training_rows) {
  fit(workflow_lgbm, data = training_rows)
}

fitted_model_for_A <- fit_workflow_on(filter(train_data, scenario == "A"))
fitted_model_for_B <- fit_workflow_on(filter(train_data, scenario == "B"))
fitted_model_for_A_and_B <- fit_workflow_on(train_data)
# Predictions ----
# Score the full test set with every fitted workflow.
# - Test rows are baked with the pooled model's recipe because that recipe
#   keeps `scenario`; the scenario-specific recipes drop it as zero-variance.
# - The baked columns (including `scenario` and `target`) are attached to each
#   model's predictions so results can be compared per scenario.
test_data_prepped <- fitted_model_for_A_and_B %>%
  extract_recipe() %>%
  bake(new_data = test_data)

# Predict class probabilities for `new_data` with `fitted_workflow`, tag every
# row with `label`, and column-bind the matching rows of `baked_data`.
# Relies on predict() returning one row per input row, in input order.
predict_with_label <- function(fitted_workflow, label, new_data, baked_data) {
  predict(fitted_workflow, new_data = new_data, type = "prob") %>%
    mutate(model = label) %>%
    bind_cols(baked_data)
}

predictions <- bind_rows(
  predict_with_label(
    fitted_model_for_A, "model tailored for A",
    new_data = test_data, baked_data = test_data_prepped
  ),
  predict_with_label(
    fitted_model_for_B, "model tailored for B",
    new_data = test_data, baked_data = test_data_prepped
  ),
  predict_with_label(
    fitted_model_for_A_and_B, "model with scenario as feature",
    new_data = test_data, baked_data = test_data_prepped
  )
)
# Results ----
# Held-out ROC AUC for every (scenario, model) pair.
# - `target` has levels "0"/"1". With event_level = "first" the event class
#   is "0", which matches the `.pred_0` probability column. Stating it
#   explicitly keeps the result independent of the global yardstick option.
# - `.groups = "drop"` returns an ungrouped tibble instead of one still
#   grouped by `scenario`.
predictions %>%
  group_by(scenario, model) %>%
  summarise(
    ROC_AUC_VALIDATION_SET = roc_auc_vec(target, .pred_0, event_level = "first"),
    .groups = "drop"
  )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment