SDS/CSC 293 CART Code
#------------------------------------------------------------------------------
# Lec14: 2019/03/25
#------------------------------------------------------------------------------
library(tidyverse)

# Pre-process iris dataset
iris <- iris %>%
  # Convert to tibble data frame:
  as_tibble() %>%
  # Add identification variable to uniquely identify each row:
  rownames_to_column(var = "ID")
# Fit CART model, in this case for classification
library(rpart)
model_formula <- as.formula(Species ~ Sepal.Length + Sepal.Width)
tree_parameters <- rpart.control(maxdepth = 3)
model_CART <- rpart(model_formula, data = iris, control = tree_parameters)

# Plot CART model
plot(model_CART, margin = 0.25)
text(model_CART, use.n = TRUE)
title("Predicting iris species using sepal length & width")
box()
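# (Added sketch, not in the original lecture code.) Printing the rpart object
# gives a text view of the same tree: each node's split rule, number of
# flowers n, and class counts. This can help with exercise a) below.
print(model_CART)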
#------------------------------------------------------------------------------
# Exercises with your partner:
# a) If the condition at a given node of the tree evaluates to true, do you go
# down the left branch or the right branch?
# b) Note that the bottom-left-most "leaf", 44/1/0, corresponds to 44 setosa, 1
# versicolor, 0 virginica, and thus the "majority" winner is setosa. Apply a
# sequence of dplyr commands to the iris data frame to end up with a data frame
# of 44 + 1 + 0 = 45 rows corresponding to these 45 flowers.
# c) Read the help file for `rpart.control` and play around with different
# arguments that control the shape of the tree in the tree_parameters object
# above:
tree_parameters_2 <- rpart.control(CHANGE THIS)
# Create training (100 flowers) and test (50 flowers)
set.seed(76)
iris_train <- iris %>%
  sample_frac(2/3)
iris_test <- iris %>%
  anti_join(iris_train, by = "ID")
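# (Added check, not in the original code.) Quick sanity check that the split
# is indeed 100 train / 50 test, since 2/3 of 150 flowers is 100:
nrow(iris_train)
nrow(iris_test)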
# 1.a) Fit model to train
model_CART_2 <- rpart(model_formula, data = iris_train, control = tree_parameters)

# 1.b) Plot CART model
plot(model_CART_2, margin = 0.25)
text(model_CART_2, use.n = TRUE)
title("Predicting iris species using sepal length & width")
box()

# 1.c) Get fitted probabilities for each class on train
p_hat_matrix_train <- model_CART_2 %>%
  predict(type = "prob", newdata = iris_train) %>%
  # Convert matrix object to data frame:
  as_tibble()
p_hat_matrix_train
# 1.d) Look at distinct probabilities
p_hat_matrix_train %>%
  distinct()
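# (Added sketch, not in the original code.) Counting how often each distinct
# probability triplet occurs shows how many training flowers land in each
# leaf of the tree:
p_hat_matrix_train %>%
  count(setosa, versicolor, virginica)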
# 2.a) Apply model to test to get fitted probabilities for each class
p_hat_matrix_test <- model_CART_2 %>%
  predict(type = "prob", newdata = iris_test) %>%
  # Convert matrix object to data frame:
  as_tibble()
p_hat_matrix_test
# 2.b) Instead of fitted probabilities, return fitted y's, where highest
# probability wins and ties are broken at random
y_hat <- model_CART_2 %>%
  predict(type = "class", newdata = iris_test) %>%
  # Convert the vector of fitted classes to a data frame:
  enframe()
y_hat
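# (Added sketch, not in the original code; assumes y_hat built from
# model_CART_2 as above.) Comparing fitted classes to observed classes gives
# a rough test-set accuracy:
mean(y_hat$value == iris_test$Species)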
# Look at the help file for the (multi-class) logarithmic loss function, which
# is one possible "score" for categorical variables when you have more than 2
# categories.
library(yardstick)
?mn_log_loss

# Create a new data frame:
bind_cols(
  # Observed y:
  Species = iris_test$Species,
  # Fitted probabilities for each class:
  p_hat_matrix_test
) %>%
  # Compute multi-class log-loss
  mn_log_loss(truth = Species, c(setosa, versicolor, virginica))
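# (Added sketch, not in the original code.) Computing the multi-class log-loss
# by hand: for each test flower, take the fitted probability assigned to its
# TRUE class, then average -log() of those probabilities. The 1e-15 floor to
# avoid log(0) is an illustrative assumption, so the result may differ
# slightly from yardstick's when a leaf assigns probability 0 to the true
# class:
p_matrix <- as.matrix(p_hat_matrix_test)
p_true <- p_matrix[cbind(seq_len(nrow(p_matrix)),
                         match(iris_test$Species, colnames(p_matrix)))]
mean(-log(pmax(p_true, 1e-15)))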
#------------------------------------------------------------------------------
# Exercises with your partner:
# d) In 1.d) you saw there are only 3 unique possible 3-tuples (i.e. triplets)
# of fitted probabilities. Which leaf in the tree does each of these 3 possible
# 3-tuples correspond to?
# e) Are larger values of the (multi-class) logarithmic loss function
# indicative of better predictions or worse predictions?
#------------------------------------------------------------------------------
# Solutions

# a) Looking at the top node of the plot of model_CART and going left, there is
# a total of 44 + 1 + 0 + 1 + 5 + 1 = 52 flowers across all children leaves.
# Since
iris %>%
  filter(Sepal.Length < 5.45) %>%
  nrow()
# yields 52 rows, if the condition evaluates to true, then you go down the
# left branch.
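# (Added check, not in the original solutions.) For comparison, the complement
# condition should account for the remaining 150 - 52 = 98 flowers in the
# right-hand leaves:
iris %>%
  filter(Sepal.Length >= 5.45) %>%
  nrow()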
# b) Note there are 0 virginica:
iris %>%
  filter(Sepal.Length < 5.45) %>%
  filter(Sepal.Width >= 2.8) %>%
  count(Species)
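# (Added sketch, not in the original solutions.) The same two filters, without
# the count(), give the 44 + 1 + 0 = 45-row data frame the exercise asked for:
iris %>%
  filter(Sepal.Length < 5.45, Sepal.Width >= 2.8)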
# c) Let's set minsplit to 100, for example:
tree_parameters_2 <- rpart.control(minsplit = 100)
model_CART_3 <- rpart(model_formula, data = iris, control = tree_parameters_2)

# Plot CART model. Once there are fewer than 100 flowers at a node, we stop
# splitting:
plot(model_CART_3, margin = 0.25)
text(model_CART_3, use.n = TRUE)
title("Predicting iris species using sepal length & width")
box()
# d)
p_hat_matrix_train %>%
  distinct()
# First row above is the 32/4/0 leaf, since we have probabilities of
# 32/36 = 0.889, 4/36 = 0.111, 0/36 = 0. The winner is setosa.
# Second row above is the 1/19/30 leaf, thus the winner is virginica.
# Third row above is the 3/11/0 leaf, thus the winner is versicolor.

# e) Look at: https://cdn-images-1.medium.com/max/1600/0*i2_eUc_t8A1EJObd.png
# i.e. logloss = -(1/N) * sum_i sum_j y_ij * log(p_ij). If p_ij = 1 for the
# true class of every observation, then log(p_ij) = 0, thus the sum is 0, and
# thus the whole loss is 0. So low values of the (multi-class) logarithmic
# loss function are indicative of good predictions.
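# (Added illustration, not in the original solutions.) Two single-observation
# examples: a confident correct prediction scores near 0, while putting little
# probability on the true class scores much higher:
-log(0.99)  # ~0.01: nearly all probability on the true class
-log(0.10)  # ~2.30: only 10% probability on the true class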