@CNuge
Last active December 14, 2017 19:38
Gradient Boosting and Parameter Tuning in R
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "40e13908-10a6-40d4-90a4-b11cef8dd283",
"_uuid": "9b99c05c450d4da5d7a62e8f60e391c35f153ddc",
"collapsed": true
},
"source": [
"# Machine Learning in R: part 2\n",
"\n",
"## This week's tutorial focuses on improvement. We will use some tuning methods to improve the predictive ability of the models we build.\n",
"\n",
"\n",
"## 1a. Cleaning and formatting the data (from last week)\n",
"\n",
"Below is the code from the loading, cleaning and train/test split sections we went over last week. This is all we need to get the data into the format required to begin training some machine learning models."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"_cell_guid": "f824f548-5141-4642-b183-3dd90c8497cc",
"_uuid": "ef290f86e40365e2130d73299065a33f3475e922"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"── Attaching packages ─────────────────────────────────────── tidyverse 1.2.1 ──\n",
"✔ ggplot2 2.2.1.9000 ✔ purrr 0.2.4 \n",
"✔ tibble 1.3.4 ✔ dplyr 0.7.4 \n",
"✔ tidyr 0.7.2 ✔ stringr 1.2.0 \n",
"✔ readr 1.1.1.9000 ✔ forcats 0.2.0 \n",
"── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──\n",
"✖ dplyr::filter() masks stats::filter()\n",
"✖ dplyr::lag() masks stats::lag()\n",
"\n",
"Attaching package: ‘xgboost’\n",
"\n",
"The following object is masked from ‘package:dplyr’:\n",
"\n",
" slice\n",
"\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
"<thead><tr><th></th><th scope=col>NEAR BAY</th><th scope=col>&lt;1H OCEAN</th><th scope=col>INLAND</th><th scope=col>NEAR OCEAN</th><th scope=col>ISLAND</th><th scope=col>longitude</th><th scope=col>latitude</th><th scope=col>housing_median_age</th><th scope=col>population</th><th scope=col>households</th><th scope=col>median_income</th><th scope=col>mean_bedrooms</th><th scope=col>mean_rooms</th><th scope=col>median_house_value</th></tr></thead>\n",
"<tbody>\n",
"\t<tr><th scope=row>2418</th><td>0 </td><td>0 </td><td>1 </td><td>0 </td><td>0 </td><td> 0.06473791</td><td> 0.4485767 </td><td>-0.05081113</td><td>-0.08342596</td><td>-0.50882695</td><td>-1.2394168 </td><td>-0.03648780</td><td>-0.4145713 </td><td> 56700 </td></tr>\n",
"\t<tr><th scope=row>9990</th><td>0 </td><td>0 </td><td>1 </td><td>0 </td><td>0 </td><td>-0.74882545</td><td> 1.6471053 </td><td>-1.08374113</td><td> 1.39212008</td><td> 2.14071836</td><td>-0.7358959 </td><td>-0.19291092</td><td>-0.1004065 </td><td>143400 </td></tr>\n",
"\t<tr><th scope=row>13440</th><td>0 </td><td>0 </td><td>1 </td><td>0 </td><td>0 </td><td> 1.07295753</td><td>-0.7218613 </td><td>-0.05081113</td><td> 0.28656434</td><td> 0.06136148</td><td> 0.1404495 </td><td>-0.18700644</td><td> 0.2732884 </td><td>128300 </td></tr>\n",
"\t<tr><th scope=row>1412</th><td>1 </td><td>0 </td><td>0 </td><td>0 </td><td>0 </td><td>-1.25293526</td><td> 1.0759315 </td><td> 0.50538194</td><td> 0.35897294</td><td> 0.42492199</td><td> 0.6344959 </td><td>-0.11581168</td><td> 0.2741324 </td><td>233200 </td></tr>\n",
"\t<tr><th scope=row>7539</th><td>0 </td><td>1 </td><td>0 </td><td>0 </td><td>0 </td><td> 0.67865382</td><td>-0.8061329 </td><td>-0.20972344</td><td> 1.03802435</td><td> 0.21829408</td><td>-1.0991931 </td><td>-0.03247975</td><td>-0.5151724 </td><td>110200 </td></tr>\n",
"\t<tr><th scope=row>4621</th><td>0 </td><td>1 </td><td>0 </td><td>0 </td><td>0 </td><td> 0.62874196</td><td>-0.7265431 </td><td> 1.61776810</td><td> 0.10024464</td><td> 0.24706505</td><td>-0.6573622 </td><td>-0.07763347</td><td>-0.4598522 </td><td>350900 </td></tr>\n",
"</tbody>\n",
"</table>\n"
],
"text/latex": [
"\\begin{tabular}{r|llllllllllllll}\n",
" & NEAR BAY & <1H OCEAN & INLAND & NEAR OCEAN & ISLAND & longitude & latitude & housing\\_median\\_age & population & households & median\\_income & mean\\_bedrooms & mean\\_rooms & median\\_house\\_value\\\\\n",
"\\hline\n",
"\t2418 & 0 & 0 & 1 & 0 & 0 & 0.06473791 & 0.4485767 & -0.05081113 & -0.08342596 & -0.50882695 & -1.2394168 & -0.03648780 & -0.4145713 & 56700 \\\\\n",
"\t9990 & 0 & 0 & 1 & 0 & 0 & -0.74882545 & 1.6471053 & -1.08374113 & 1.39212008 & 2.14071836 & -0.7358959 & -0.19291092 & -0.1004065 & 143400 \\\\\n",
"\t13440 & 0 & 0 & 1 & 0 & 0 & 1.07295753 & -0.7218613 & -0.05081113 & 0.28656434 & 0.06136148 & 0.1404495 & -0.18700644 & 0.2732884 & 128300 \\\\\n",
"\t1412 & 1 & 0 & 0 & 0 & 0 & -1.25293526 & 1.0759315 & 0.50538194 & 0.35897294 & 0.42492199 & 0.6344959 & -0.11581168 & 0.2741324 & 233200 \\\\\n",
"\t7539 & 0 & 1 & 0 & 0 & 0 & 0.67865382 & -0.8061329 & -0.20972344 & 1.03802435 & 0.21829408 & -1.0991931 & -0.03247975 & -0.5151724 & 110200 \\\\\n",
"\t4621 & 0 & 1 & 0 & 0 & 0 & 0.62874196 & -0.7265431 & 1.61776810 & 0.10024464 & 0.24706505 & -0.6573622 & -0.07763347 & -0.4598522 & 350900 \\\\\n",
"\\end{tabular}\n"
],
"text/markdown": [
"\n",
"| <!--/--> | NEAR BAY | <1H OCEAN | INLAND | NEAR OCEAN | ISLAND | longitude | latitude | housing_median_age | population | households | median_income | mean_bedrooms | mean_rooms | median_house_value | \n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 2418 | 0 | 0 | 1 | 0 | 0 | 0.06473791 | 0.4485767 | -0.05081113 | -0.08342596 | -0.50882695 | -1.2394168 | -0.03648780 | -0.4145713 | 56700 | \n",
"| 9990 | 0 | 0 | 1 | 0 | 0 | -0.74882545 | 1.6471053 | -1.08374113 | 1.39212008 | 2.14071836 | -0.7358959 | -0.19291092 | -0.1004065 | 143400 | \n",
"| 13440 | 0 | 0 | 1 | 0 | 0 | 1.07295753 | -0.7218613 | -0.05081113 | 0.28656434 | 0.06136148 | 0.1404495 | -0.18700644 | 0.2732884 | 128300 | \n",
"| 1412 | 1 | 0 | 0 | 0 | 0 | -1.25293526 | 1.0759315 | 0.50538194 | 0.35897294 | 0.42492199 | 0.6344959 | -0.11581168 | 0.2741324 | 233200 | \n",
"| 7539 | 0 | 1 | 0 | 0 | 0 | 0.67865382 | -0.8061329 | -0.20972344 | 1.03802435 | 0.21829408 | -1.0991931 | -0.03247975 | -0.5151724 | 110200 | \n",
"| 4621 | 0 | 1 | 0 | 0 | 0 | 0.62874196 | -0.7265431 | 1.61776810 | 0.10024464 | 0.24706505 | -0.6573622 | -0.07763347 | -0.4598522 | 350900 | \n",
"\n",
"\n"
],
"text/plain": [
" NEAR BAY <1H OCEAN INLAND NEAR OCEAN ISLAND longitude latitude \n",
"2418 0 0 1 0 0 0.06473791 0.4485767\n",
"9990 0 0 1 0 0 -0.74882545 1.6471053\n",
"13440 0 0 1 0 0 1.07295753 -0.7218613\n",
"1412 1 0 0 0 0 -1.25293526 1.0759315\n",
"7539 0 1 0 0 0 0.67865382 -0.8061329\n",
"4621 0 1 0 0 0 0.62874196 -0.7265431\n",
" housing_median_age population households median_income mean_bedrooms\n",
"2418 -0.05081113 -0.08342596 -0.50882695 -1.2394168 -0.03648780 \n",
"9990 -1.08374113 1.39212008 2.14071836 -0.7358959 -0.19291092 \n",
"13440 -0.05081113 0.28656434 0.06136148 0.1404495 -0.18700644 \n",
"1412 0.50538194 0.35897294 0.42492199 0.6344959 -0.11581168 \n",
"7539 -0.20972344 1.03802435 0.21829408 -1.0991931 -0.03247975 \n",
"4621 1.61776810 0.10024464 0.24706505 -0.6573622 -0.07763347 \n",
" mean_rooms median_house_value\n",
"2418 -0.4145713 56700 \n",
"9990 -0.1004065 143400 \n",
"13440 0.2732884 128300 \n",
"1412 0.2741324 233200 \n",
"7539 -0.5151724 110200 \n",
"4621 -0.4598522 350900 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"library(tidyverse)\n",
"library(xgboost)\n",
"\n",
"housing = read.csv('../input/housing.csv')\n",
"\n",
"housing$total_bedrooms[is.na(housing$total_bedrooms)] = median(housing$total_bedrooms , na.rm = TRUE)\n",
"\n",
"housing$mean_bedrooms = housing$total_bedrooms/housing$households\n",
"housing$mean_rooms = housing$total_rooms/housing$households\n",
"\n",
"drops = c('total_bedrooms', 'total_rooms')\n",
"\n",
"housing = housing[ , !(names(housing) %in% drops)]\n",
"\n",
"categories = unique(housing$ocean_proximity)\n",
"#split the categories off\n",
"cat_housing = data.frame(ocean_proximity = housing$ocean_proximity)\n",
"\n",
"for(cat in categories){\n",
" cat_housing[,cat] = rep(0, times= nrow(cat_housing))\n",
"}\n",
"\n",
"for(i in 1:length(cat_housing$ocean_proximity)){\n",
" cat = as.character(cat_housing$ocean_proximity[i])\n",
" cat_housing[,cat][i] = 1\n",
"}\n",
"\n",
"cat_columns = names(cat_housing)\n",
"keep_columns = cat_columns[cat_columns != 'ocean_proximity']\n",
"cat_housing = select(cat_housing,one_of(keep_columns))\n",
"drops = c('ocean_proximity','median_house_value')\n",
"housing_num = housing[ , !(names(housing) %in% drops)]\n",
"\n",
"\n",
"scaled_housing_num = scale(housing_num)\n",
"\n",
"cleaned_housing = cbind(cat_housing, scaled_housing_num, median_house_value=housing$median_house_value)\n",
"\n",
"\n",
"set.seed(19) # Set a random seed so that same sample can be reproduced in future runs\n",
"\n",
"sample = sample.int(n = nrow(cleaned_housing), size = floor(.8*nrow(cleaned_housing)), replace = F)\n",
"train = cleaned_housing[sample, ] #just the samples\n",
"test = cleaned_housing[-sample, ] #everything but the samples\n",
"\n",
"\n",
"train_y = train[,'median_house_value']\n",
"train_x = train[, names(train) !='median_house_value']\n",
"\n",
"test_y = test[,'median_house_value']\n",
"test_x = test[, names(test) !='median_house_value']\n",
"\n",
"head(train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1b. Cleaning - The tidyverse way! \n",
"\n",
"The code below does the same thing as the code above, but uses the tidyverse. I've pulled out the bare-bones parts of Karl's 'Housing_R_tidy.r' script needed to get the data where we want it (i.e. I removed all the graphs, head() calls, etc.; see his original script for the notes and visual embellishments).\n",
"\n",
"I like this code more than my original version because:\n",
"1. It is easy to follow the workflow: the magrittr pipe (%>%) makes it easy to see where one cleaning task ends and the next begins.\n",
"2. It is more concise.\n",
"3. Putting comments (#) after the pipes (%>%) looks professional and makes the code more readable. Being able to share your code with others and have them understand it is very important!\n",
"4. It runs faster: the one-hot encoding step compares the whole column against each category at once instead of looping over every row.\n",
"\n",
"Note: the tibble docs list the function as as_tibble(), not as.tibble() as it was first written. Oddly, as.tibble() worked in my normal R setup but threw an error in the Jupyter notebook, and I'm not sure why."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"library(tidyverse)\n",
"\n",
"housing.tidy = read_csv('../input/housing.csv')\n",
"\n",
"housing.tidy = housing.tidy %>% \n",
" mutate(total_bedrooms = ifelse(is.na(total_bedrooms), \n",
" median(total_bedrooms, na.rm = T),\n",
" total_bedrooms),\n",
" mean_bedrooms = total_bedrooms/households,\n",
" mean_rooms = total_rooms/households) %>%\n",
" select(-c(total_rooms, total_bedrooms))\n",
"\n",
"\n",
"categories = unique(housing.tidy$ocean_proximity) # all categories\n",
"\n",
"cat_housing.tidy = categories %>% # compare the full vector against each category consecutively\n",
" lapply(function(x) as.numeric(housing.tidy$ocean_proximity == x)) %>% # convert to numeric\n",
" do.call(\"cbind\", .) %>% as_tibble() # clean up\n",
"colnames(cat_housing.tidy) = categories # make nice column names\n",
"\n",
"cleaned_housing.tidy = housing.tidy %>% \n",
" select(-c(ocean_proximity, median_house_value)) %>%\n",
" scale() %>% as_tibble() %>%\n",
" bind_cols(cat_housing.tidy) %>%\n",
" add_column(median_house_value = housing.tidy$median_house_value)\n",
"\n",
"set.seed(19) # Set a random seed so that same sample can be reproduced in future runs\n",
"\n",
"sample = sample.int(n = nrow(cleaned_housing.tidy), size = floor(.8*nrow(cleaned_housing.tidy)), replace = F)\n",
"train = cleaned_housing.tidy[sample, ] #just the samples\n",
"test = cleaned_housing.tidy[-sample, ] #everything but the samples\n",
" \n",
"head(train)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "35a0b9db-912d-4bf9-8210-1bbebc43c400",
"_uuid": "c20285da3cc492fe35d87fe29571cb7baea9820f"
},
"source": [
"## 2a. Last week's random forest model\n",
"\n",
"This is the model we built at the end of last week's session. I'm running it again here so that we have a benchmark for the other models we train today."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"_cell_guid": "43aafdcc-ba52-4136-89b0-e207325dba23",
"_uuid": "898e3c2f7c52bc6ab1c2d02a7c906e620b332898"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"randomForest 4.6-12\n",
"Type rfNews() to see new features/changes/bug fixes.\n",
"\n",
"Attaching package: ‘randomForest’\n",
"\n",
"The following object is masked from ‘package:dplyr’:\n",
"\n",
" combine\n",
"\n",
"The following object is masked from ‘package:ggplot2’:\n",
"\n",
" margin\n",
"\n"
]
},
{
"data": {
"text/html": [
"<ol class=list-inline>\n",
"\t<li>'call'</li>\n",
"\t<li>'type'</li>\n",
"\t<li>'predicted'</li>\n",
"\t<li>'mse'</li>\n",
"\t<li>'rsq'</li>\n",
"\t<li>'oob.times'</li>\n",
"\t<li>'importance'</li>\n",
"\t<li>'importanceSD'</li>\n",
"\t<li>'localImportance'</li>\n",
"\t<li>'proximity'</li>\n",
"\t<li>'ntree'</li>\n",
"\t<li>'mtry'</li>\n",
"\t<li>'forest'</li>\n",
"\t<li>'coefs'</li>\n",
"\t<li>'y'</li>\n",
"\t<li>'test'</li>\n",
"\t<li>'inbag'</li>\n",
"</ol>\n"
],
"text/latex": [
"\\begin{enumerate*}\n",
"\\item 'call'\n",
"\\item 'type'\n",
"\\item 'predicted'\n",
"\\item 'mse'\n",
"\\item 'rsq'\n",
"\\item 'oob.times'\n",
"\\item 'importance'\n",
"\\item 'importanceSD'\n",
"\\item 'localImportance'\n",
"\\item 'proximity'\n",
"\\item 'ntree'\n",
"\\item 'mtry'\n",
"\\item 'forest'\n",
"\\item 'coefs'\n",
"\\item 'y'\n",
"\\item 'test'\n",
"\\item 'inbag'\n",
"\\end{enumerate*}\n"
],
"text/markdown": [
"1. 'call'\n",
"2. 'type'\n",
"3. 'predicted'\n",
"4. 'mse'\n",
"5. 'rsq'\n",
"6. 'oob.times'\n",
"7. 'importance'\n",
"8. 'importanceSD'\n",
"9. 'localImportance'\n",
"10. 'proximity'\n",
"11. 'ntree'\n",
"12. 'mtry'\n",
"13. 'forest'\n",
"14. 'coefs'\n",
"15. 'y'\n",
"16. 'test'\n",
"17. 'inbag'\n",
"\n",
"\n"
],
"text/plain": [
" [1] \"call\" \"type\" \"predicted\" \"mse\" \n",
" [5] \"rsq\" \"oob.times\" \"importance\" \"importanceSD\" \n",
" [9] \"localImportance\" \"proximity\" \"ntree\" \"mtry\" \n",
"[13] \"forest\" \"coefs\" \"y\" \"test\" \n",
"[17] \"inbag\" "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table>\n",
"<thead><tr><th></th><th scope=col>%IncMSE</th><th scope=col>IncNodePurity</th></tr></thead>\n",
"<tbody>\n",
"\t<tr><th scope=row>NEAR BAY</th><td>4.450369e+08</td><td>1.391014e+12</td></tr>\n",
"\t<tr><th scope=row>&lt;1H OCEAN</th><td>1.427100e+09</td><td>4.338351e+12</td></tr>\n",
"\t<tr><th scope=row>INLAND</th><td>3.703060e+09</td><td>3.116839e+13</td></tr>\n",
"\t<tr><th scope=row>NEAR OCEAN</th><td>4.262002e+08</td><td>2.119833e+12</td></tr>\n",
"\t<tr><th scope=row>ISLAND</th><td>6.761432e+04</td><td>1.682004e+10</td></tr>\n",
"\t<tr><th scope=row>longitude</th><td>6.686360e+09</td><td>2.580455e+13</td></tr>\n",
"\t<tr><th scope=row>latitude</th><td>5.375308e+09</td><td>2.235183e+13</td></tr>\n",
"\t<tr><th scope=row>housing_median_age</th><td>1.062666e+09</td><td>9.808021e+12</td></tr>\n",
"\t<tr><th scope=row>population</th><td>1.081431e+09</td><td>7.441795e+12</td></tr>\n",
"\t<tr><th scope=row>households</th><td>1.193537e+09</td><td>7.897400e+12</td></tr>\n",
"\t<tr><th scope=row>median_income</th><td>8.325735e+09</td><td>7.167956e+13</td></tr>\n",
"\t<tr><th scope=row>mean_bedrooms</th><td>4.083603e+08</td><td>7.597531e+12</td></tr>\n",
"\t<tr><th scope=row>mean_rooms</th><td>1.890830e+09</td><td>2.155527e+13</td></tr>\n",
"</tbody>\n",
"</table>\n"
],
"text/latex": [
"\\begin{tabular}{r|ll}\n",
" & \\%IncMSE & IncNodePurity\\\\\n",
"\\hline\n",
"\tNEAR BAY & 4.450369e+08 & 1.391014e+12\\\\\n",
"\t<1H OCEAN & 1.427100e+09 & 4.338351e+12\\\\\n",
"\tINLAND & 3.703060e+09 & 3.116839e+13\\\\\n",
"\tNEAR OCEAN & 4.262002e+08 & 2.119833e+12\\\\\n",
"\tISLAND & 6.761432e+04 & 1.682004e+10\\\\\n",
"\tlongitude & 6.686360e+09 & 2.580455e+13\\\\\n",
"\tlatitude & 5.375308e+09 & 2.235183e+13\\\\\n",
"\thousing\\_median\\_age & 1.062666e+09 & 9.808021e+12\\\\\n",
"\tpopulation & 1.081431e+09 & 7.441795e+12\\\\\n",
"\thouseholds & 1.193537e+09 & 7.897400e+12\\\\\n",
"\tmedian\\_income & 8.325735e+09 & 7.167956e+13\\\\\n",
"\tmean\\_bedrooms & 4.083603e+08 & 7.597531e+12\\\\\n",
"\tmean\\_rooms & 1.890830e+09 & 2.155527e+13\\\\\n",
"\\end{tabular}\n"
],
"text/markdown": [
"\n",
"| <!--/--> | %IncMSE | IncNodePurity | \n",
"|---|---|---|\n",
"| NEAR BAY | 4.450369e+08 | 1.391014e+12 | \n",
"| <1H OCEAN | 1.427100e+09 | 4.338351e+12 | \n",
"| INLAND | 3.703060e+09 | 3.116839e+13 | \n",
"| NEAR OCEAN | 4.262002e+08 | 2.119833e+12 | \n",
"| ISLAND | 6.761432e+04 | 1.682004e+10 | \n",
"| longitude | 6.686360e+09 | 2.580455e+13 | \n",
"| latitude | 5.375308e+09 | 2.235183e+13 | \n",
"| housing_median_age | 1.062666e+09 | 9.808021e+12 | \n",
"| population | 1.081431e+09 | 7.441795e+12 | \n",
"| households | 1.193537e+09 | 7.897400e+12 | \n",
"| median_income | 8.325735e+09 | 7.167956e+13 | \n",
"| mean_bedrooms | 4.083603e+08 | 7.597531e+12 | \n",
"| mean_rooms | 1.890830e+09 | 2.155527e+13 | \n",
"\n",
"\n"
],
"text/plain": [
" %IncMSE IncNodePurity\n",
"NEAR BAY 4.450369e+08 1.391014e+12 \n",
"<1H OCEAN 1.427100e+09 4.338351e+12 \n",
"INLAND 3.703060e+09 3.116839e+13 \n",
"NEAR OCEAN 4.262002e+08 2.119833e+12 \n",
"ISLAND 6.761432e+04 1.682004e+10 \n",
"longitude 6.686360e+09 2.580455e+13 \n",
"latitude 5.375308e+09 2.235183e+13 \n",
"housing_median_age 1.062666e+09 9.808021e+12 \n",
"population 1.081431e+09 7.441795e+12 \n",
"households 1.193537e+09 7.897400e+12 \n",
"median_income 8.325735e+09 7.167956e+13 \n",
"mean_bedrooms 4.083603e+08 7.597531e+12 \n",
"mean_rooms 1.890830e+09 2.155527e+13 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<dl class=dl-horizontal>\n",
"\t<dt>median_income</dt>\n",
"\t\t<dd>8325735051.79272</dd>\n",
"\t<dt>longitude</dt>\n",
"\t\t<dd>6686359889.42546</dd>\n",
"\t<dt>latitude</dt>\n",
"\t\t<dd>5375307640.79577</dd>\n",
"\t<dt>INLAND</dt>\n",
"\t\t<dd>3703059838.7179</dd>\n",
"\t<dt>mean_rooms</dt>\n",
"\t\t<dd>1890829634.71922</dd>\n",
"\t<dt>&lt;1H OCEAN</dt>\n",
"\t\t<dd>1427100136.39297</dd>\n",
"\t<dt>households</dt>\n",
"\t\t<dd>1193537183.56301</dd>\n",
"\t<dt>population</dt>\n",
"\t\t<dd>1081430788.36861</dd>\n",
"\t<dt>housing_median_age</dt>\n",
"\t\t<dd>1062665767.43207</dd>\n",
"\t<dt>NEAR BAY</dt>\n",
"\t\t<dd>445036940.211241</dd>\n",
"\t<dt>NEAR OCEAN</dt>\n",
"\t\t<dd>426200169.288826</dd>\n",
"\t<dt>mean_bedrooms</dt>\n",
"\t\t<dd>408360297.039669</dd>\n",
"\t<dt>ISLAND</dt>\n",
"\t\t<dd>67614.3178751267</dd>\n",
"</dl>\n"
],
"text/latex": [
"\\begin{description*}\n",
"\\item[median\\textbackslash{}\\_income] 8325735051.79272\n",
"\\item[longitude] 6686359889.42546\n",
"\\item[latitude] 5375307640.79577\n",
"\\item[INLAND] 3703059838.7179\n",
"\\item[mean\\textbackslash{}\\_rooms] 1890829634.71922\n",
"\\item[<1H OCEAN] 1427100136.39297\n",
"\\item[households] 1193537183.56301\n",
"\\item[population] 1081430788.36861\n",
"\\item[housing\\textbackslash{}\\_median\\textbackslash{}\\_age] 1062665767.43207\n",
"\\item[NEAR BAY] 445036940.211241\n",
"\\item[NEAR OCEAN] 426200169.288826\n",
"\\item[mean\\textbackslash{}\\_bedrooms] 408360297.039669\n",
"\\item[ISLAND] 67614.3178751267\n",
"\\end{description*}\n"
],
"text/markdown": [
"median_income\n",
": 8325735051.79272\n",
"longitude\n",
": 6686359889.42546\n",
"latitude\n",
": 5375307640.79577\n",
"INLAND\n",
": 3703059838.7179\n",
"mean_rooms\n",
": 1890829634.71922\n",
"&lt;1H OCEAN\n",
": 1427100136.39297\n",
"households\n",
": 1193537183.56301\n",
"population\n",
": 1081430788.36861\n",
"housing_median_age\n",
": 1062665767.43207\n",
"NEAR BAY\n",
": 445036940.211241\n",
"NEAR OCEAN\n",
": 426200169.288826\n",
"mean_bedrooms\n",
": 408360297.039669\n",
"ISLAND\n",
": 67614.3178751267\n",
"\n"
],
"text/plain": [
" median_income longitude latitude INLAND \n",
" 8.325735e+09 6.686360e+09 5.375308e+09 3.703060e+09 \n",
" mean_rooms <1H OCEAN households population \n",
" 1.890830e+09 1.427100e+09 1.193537e+09 1.081431e+09 \n",
"housing_median_age NEAR BAY NEAR OCEAN mean_bedrooms \n",
" 1.062666e+09 4.450369e+08 4.262002e+08 4.083603e+08 \n",
" ISLAND \n",
" 6.761432e+04 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"48973.7871553834"
],
"text/latex": [
"48973.7871553834"
],
"text/markdown": [
"48973.7871553834"
],
"text/plain": [
"[1] 48973.79"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"48392.8697349542"
],
"text/latex": [
"48392.8697349542"
],
"text/markdown": [
"48392.8697349542"
],
"text/plain": [
"[1] 48392.87"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"########\n",
"# Random Forest Model\n",
"########\n",
"library(randomForest)\n",
"rf_model = randomForest(train_x, y = train_y , ntree = 500, importance = TRUE)\n",
"\n",
"names(rf_model) # all the components stored in the fitted model object\n",
"\n",
"importance_dat = rf_model$importance\n",
"importance_dat\n",
"\n",
"sorted_predictors = sort(importance_dat[,1], decreasing=TRUE)\n",
"sorted_predictors\n",
"\n",
"oob_prediction = predict(rf_model) # with no newdata, predict() returns the out-of-bag (OOB) predictions\n",
"\n",
"# the OOB MSE is also stored as the last element of rf_model$mse (after all 500 trees),\n",
"# but computing it by hand shows what it measures\n",
"oob_mse = mean((oob_prediction - train_y)^2)\n",
"oob_rmse = sqrt(oob_mse)\n",
"oob_rmse\n",
"\n",
"\n",
"y_pred_rf = predict(rf_model , test_x)\n",
"test_mse = mean((y_pred_rf - test_y)^2)\n",
"test_rmse = sqrt(test_mse)\n",
"test_rmse # 48392.87 in the run shown here; varies slightly if the cell is rerun"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "b1f5a79e-afa0-4a1d-b195-44c6aed9b5d5",
"_uuid": "c0ab7db6aa402d5a2ca9370f1089b6583a7dfac6"
},
"source": [
"A test RMSE of $48,392 is the benchmark from that random forest run."
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "7054f0a3-b2ed-4e20-9b2d-f12894fe1f77",
"_uuid": "9c03100b6351a3abc3b8ba656b7876c43f8c9e84"
},
"source": [
"## 2b. Gradient Boosting\n",
"\n",
"Gradient boosting is an ensemble supervised machine learning method that, like the random forest algorithm we explored last week, combines many decision trees. Recall that for a random forest we grew 500 decision trees and took the mean of their predictions to get a 'wisdom of the crowd' effect and arrive at a more accurate prediction than any one tree would provide.\n",
"\n",
"Here we use the Extreme Gradient Boosting library (xgboost) to implement this in R. Note that the 'Extreme' in extreme gradient boosting refers to its computational efficiency; the algorithm could more accurately be described as 'regularized gradient boosting'.\n",
"\n",
"### Gradient Boosting - Algorithm details\n",
"\n",
"Extreme gradient boosting also builds a forest of trees, but does so in an additive manner. The trees are grown sequentially: already learned trees are kept, and each new tree is added to reduce the objective function (the error in the predictions) of the ensemble built so far. Each new tree is fit to the errors that remain after the previous trees: for squared-error loss these are simply the residuals, and in general they are given by the (negative) gradient of the loss (xgboost also uses its second derivative). The new tree's predictions, scaled down by the learning rate, are then added to the running total, so the model descends step by step towards an optimal set of predictive trees.\n",
"\n",
"The objective that xgboost minimizes also includes regularization terms (penalties on the number of leaves and on the size of the leaf weights) to reduce overfitting.\n",
"\n",
"One difference between boosting and random forests: in boosting, because the growth of a particular tree takes into account the other trees that have already been grown, smaller trees are typically sufficient (fewer splits and less depth)."
]
},
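{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the additive idea concrete, here is a minimal from-scratch sketch of gradient boosting for squared-error loss. It is only an illustration, not how xgboost is implemented internally (there is no regularization and no second-order information). It assumes the rpart package (shipped with standard R installations) for the individual regression trees. The simulated data and the helper names `boost_fit` and `boost_predict` are hypothetical.\n",
"\n",
"```r\n",
"library(rpart)\n",
"\n",
"# Gradient boosting for squared error: each new tree is fit to the residuals\n",
"# (the negative gradient of the squared-error loss) of the ensemble so far.\n",
"boost_fit = function(x, y, n_trees = 100, eta = 0.1, max_depth = 2) {\n",
"  f0 = mean(y)                        # start from a constant prediction\n",
"  pred = rep(f0, length(y))\n",
"  trees = vector('list', n_trees)\n",
"  for (m in seq_len(n_trees)) {\n",
"    dat = data.frame(x, resid = y - pred)  # current residuals as the target\n",
"    trees[[m]] = rpart(resid ~ ., data = dat,\n",
"                       control = rpart.control(maxdepth = max_depth, cp = 0, minsplit = 10))\n",
"    pred = pred + eta * predict(trees[[m]], dat)  # add the shrunken tree\n",
"  }\n",
"  list(f0 = f0, trees = trees, eta = eta)\n",
"}\n",
"\n",
"# prediction = constant start + eta * (sum of all tree predictions)\n",
"boost_predict = function(model, x) {\n",
"  pred = rep(model$f0, nrow(x))\n",
"  for (tree in model$trees) pred = pred + model$eta * predict(tree, data.frame(x))\n",
"  pred\n",
"}\n",
"\n",
"# hypothetical simulated data, just to exercise the functions\n",
"set.seed(1)\n",
"x = data.frame(a = runif(500), b = runif(500))\n",
"y = sin(6 * x$a) + 2 * x$b + rnorm(500, sd = 0.1)\n",
"\n",
"model = boost_fit(x, y)\n",
"train_rmse = sqrt(mean((boost_predict(model, x) - y)^2))\n",
"train_rmse  # compare with sd(y), roughly the RMSE of always predicting the mean\n",
"```\n",
"\n",
"Because every tree only corrects what the previous ones got wrong, shallow trees (here max_depth = 2) are enough. This matches the point above about boosting needing smaller trees than a random forest."
]
},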
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"_cell_guid": "fcbe74d3-56cf-45a1-ae35-a8508ce890c5",
"_uuid": "5b1512d60b97c1e416ebc19602dabb7116677afd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1]\ttrain-rmse:171405.234375\ttest-rmse:173186.609375 \n",
"Multiple eval metrics are present. Will use test_rmse for early stopping.\n",
"Will train until test_rmse hasn't improved in 50 rounds.\n",
"\n",
"Stopping. Best iteration:\n",
"[56]\ttrain-rmse:21571.712891\ttest-rmse:47723.835938\n",
"\n"
]
}
],
"source": [
"######\n",
"# XG Boost\n",
"######\n",
"# see the docs: http://cran.fhcrc.org/web/packages/xgboost/vignettes/xgboost.pdf\n",
"library(xgboost)\n",
"\n",
"#put into the xgb matrix format\n",
"dtrain = xgb.DMatrix(data = as.matrix(train_x), label = train_y )\n",
"dtest = xgb.DMatrix(data = as.matrix(test_x), label = test_y)\n",
"\n",
"# these are the datasets the rmse is evaluated for at each iteration\n",
"watchlist = list(train=dtrain, test=dtest)\n",
"\n",
"# try 1: a set of parameters that I've found works well for most problems\n",
"\n",
"bst = xgb.train(data = dtrain, \n",
" max.depth = 8, \n",
" eta = 0.3, \n",
" nthread = 2, \n",
" nround = 1000, \n",
" watchlist = watchlist, \n",
" objective = \"reg:linear\", \n",
" early_stopping_rounds = 50,\n",
" print_every_n = 500)"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "bc1cc2a8-40c5-43a3-9335-a961842a083e",
"_uuid": "bf791751c4e76a70f704df377d267e0b8ef3468d"
},
"source": [
"So our first run gets a test RMSE of $47,723, an improvement over our benchmark model. That isn't the end of the story, though: we can try to squeeze out further improvements through 'hyperparameter tuning'. A hyperparameter is one of the adjustable options (such as eta or max.depth) that we pass to the algorithm along with our data, as opposed to the parameters the model learns from the data.\n",
"\n",
"## 2c. Tuning the algorithm - hyperparameters for xgboost\n",
"\n",
"Boosting has three main tuning parameters that we can focus on:\n",
"\n",
"1. The number of trees. Here we use a good trick: instead of specifying an exact number, we give the algorithm a big number (nround = 10000) and the parameter early_stopping_rounds = 50. This effectively means: 'keep iteratively growing trees until you have 10,000 of them, or stop early if the scores haven't improved for the last 50 rounds'.\n",
"2. The shrinkage parameter λ (eta in the params), a small positive number. This controls the rate at which boosting learns: each new tree's predictions are multiplied by eta before being added to the model. Typical values are 0.01 or 0.001, and the right choice can depend on the problem. A very small λ can require a very large number of trees (B) to achieve good performance. (Don't confuse this λ with xgboost's own lambda parameter, which is an L2 regularization penalty on the leaf weights.)\n",
"3. The depth of each tree (max.depth), which limits the number of splits per tree and so controls the complexity of the boosted ensemble.\n",
"\n",
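"### A 'slower learning' model\n",
"\n",
"Here we try a model that learns more slowly: each tree's contribution is shrunk more (eta = 0.01 instead of 0.3) and the trees are shallower (max.depth = 5 instead of 8), so we also allow more iterations (nround = 10000) to account for the fact that the model will take longer to learn."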
]
},
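{
"cell_type": "markdown",
"metadata": {},
"source": [
"The early-stopping rule in point 1 can be sketched in a few lines of base R: remember the best round so far and stop once `patience` rounds have passed without a new best. This is a simplified re-implementation for illustration, not xgboost's own code. The function name `early_stop` and the toy scores are hypothetical.\n",
"\n",
"```r\n",
"# Scan per-round validation scores (lower is better) and stop once\n",
"# `patience` rounds have passed without a new best score.\n",
"early_stop = function(scores, patience = 50) {\n",
"  best_iter = 1\n",
"  for (i in seq_along(scores)) {\n",
"    if (scores[i] < scores[best_iter]) best_iter = i\n",
"    if (i - best_iter >= patience) {\n",
"      return(list(stopped_at = i, best_iter = best_iter))\n",
"    }\n",
"  }\n",
"  # ran out of rounds before the patience was exhausted\n",
"  list(stopped_at = length(scores), best_iter = best_iter)\n",
"}\n",
"\n",
"# toy scores: improve for 100 rounds, then slowly get worse\n",
"toy_scores = c(seq(100, 1, length.out = 100), seq(1.1, 20, length.out = 200))\n",
"res = early_stop(toy_scores, patience = 50)\n",
"res$stopped_at  # 150\n",
"res$best_iter   # 100\n",
"```\n",
"\n",
"Unless it runs out of rounds first, training stops exactly `patience` rounds after the best one. This is why a run with early_stopping_rounds = 50 reports a best iteration 50 rounds before the point where it stopped."
]
},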
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"_cell_guid": "a78653c4-d767-4fa0-823f-0ea228665430",
"_uuid": "f5e57221c9ca151ad8186d8bd4070820cef860fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1]\ttrain-rmse:234537.453125\ttest-rmse:235377.093750 \n",
"Multiple eval metrics are present. Will use test_rmse for early stopping.\n",
"Will train until test_rmse hasn't improved in 50 rounds.\n",
"\n",
"[501]\ttrain-rmse:49887.167969\ttest-rmse:53810.851562 \n",
"[1001]\ttrain-rmse:43695.054688\ttest-rmse:49724.628906 \n",
"[1501]\ttrain-rmse:39971.750000\ttest-rmse:48036.140625 \n",
"[2001]\ttrain-rmse:37479.328125\ttest-rmse:47158.046875 \n",
"[2501]\ttrain-rmse:35481.843750\ttest-rmse:46556.929688 \n",
"[3001]\ttrain-rmse:33934.000000\ttest-rmse:46190.351562 \n",
"[3501]\ttrain-rmse:32535.261719\ttest-rmse:45904.425781 \n",
"[4001]\ttrain-rmse:31228.515625\ttest-rmse:45617.722656 \n",
"[4501]\ttrain-rmse:30171.023438\ttest-rmse:45451.519531 \n",
"[5001]\ttrain-rmse:29281.529297\ttest-rmse:45383.992188 \n",
"[5501]\ttrain-rmse:28471.605469\ttest-rmse:45289.558594 \n",
"[6001]\ttrain-rmse:27689.521484\ttest-rmse:45230.863281 \n",
"Stopping. Best iteration:\n",
"[5959]\ttrain-rmse:27752.142578\ttest-rmse:45225.968750\n",
"\n"
]
}
],
"source": [
"bst_slow = xgb.train(data = dtrain, \n",
" max.depth=5, \n",
" eta = 0.01, \n",
" nthread = 2, \n",
" nround = 10000, \n",
" watchlist = watchlist, \n",
" objective = \"reg:linear\", \n",
" early_stopping_rounds = 50,\n",
" print_every_n = 500)"
]
},
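{
"cell_type": "markdown",
"metadata": {},
"source": [
"A back-of-the-envelope calculation shows why the smaller eta needs so many more rounds. Assume, purely for illustration, that every new tree fits the current residual exactly. Each round then removes a fraction eta of what is left, so after m rounds a fraction (1 - eta)^m of the original residual remains. The helper names `remaining` and `rounds_needed` are hypothetical.\n",
"\n",
"```r\n",
"# Fraction of the original residual left after m rounds, under the idealized\n",
"# assumption that each tree fits the residual exactly and is shrunk by eta.\n",
"remaining = function(eta, m) (1 - eta)^m\n",
"\n",
"# Rounds needed at learning rate eta to shrink the residual as much as\n",
"# m_ref rounds do at learning rate eta_ref.\n",
"rounds_needed = function(eta, eta_ref, m_ref) {\n",
"  ceiling(m_ref * log(1 - eta_ref) / log(1 - eta))\n",
"}\n",
"\n",
"remaining(0.3, 10)            # about 0.028\n",
"rounds_needed(0.01, 0.3, 10)  # 355\n",
"```\n",
"\n",
"Real trees never fit the residuals exactly, so this only indicates the scale. Moving from eta = 0.3 to eta = 0.01 multiplies the number of rounds needed by roughly 35. That is consistent with the slow model above needing thousands of rounds where the first model needed dozens (the slow model also uses shallower trees)."
]
},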
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"_cell_guid": "6dbd1717-7124-4533-9c67-217c9ff3f5fc",
"_uuid": "abfd8ae94d75b904fe6775474a2ae409b50dc869"
},
"outputs": [
{
"data": {
"text/html": [
"<strong>test-rmse:</strong> 0.93457531720119"
],
"text/latex": [
"\\textbf{test-rmse:} 0.93457531720119"
],
"text/markdown": [
"**test-rmse:** 0.93457531720119"
],
"text/plain": [
"test-rmse \n",
"0.9345753 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"rf_benchmark = 48392\n",
"\n",
"bst_slow$best_score / rf_benchmark"
]
},
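{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the RMSE used throughout is the square root of the mean squared difference between predictions and observed values. The ratio above converts to a percentage improvement as follows. The `rmse` helper is a hypothetical convenience function, and the two scores are copied from the outputs above.\n",
"\n",
"```r\n",
"# root mean squared error between predictions and observed values\n",
"rmse = function(pred, obs) sqrt(mean((pred - obs)^2))\n",
"rmse(c(100, 200, 300), c(110, 190, 300))  # sqrt(200 / 3), about 8.16\n",
"\n",
"rf_benchmark = 48392      # random forest test RMSE from last week's model\n",
"xgb_slow = 45225.96875    # best test RMSE of bst_slow in the run above\n",
"\n",
"improvement = 1 - xgb_slow / rf_benchmark\n",
"round(100 * improvement, 1)  # 6.5 (percent)\n",
"```"
]
},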
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "97291bb2-8919-4bc3-8b63-6adc2f58276b",
"_uuid": "201969793f39d71add07b8137317f221413cadbc"
},
"source": [
"Things to note:\n",
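"- Training ran for 6009 iterations and then stopped, because the test RMSE had not improved for 50 rounds; thanks to the early_stopping_rounds parameter, iteration 5959 is reported as the best one.\n",
"- So that is an improvement of ~6.5% in RMSE over last week. Yay, let's go home! Wait! What we have done here is use the test set to decide when to stop training, so it has effectively been used to fit the model alongside the training set. This can, and likely has, led to overfitting to the test set and an optimistic error estimate. \n",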
"\n",
"## Problems to address\n",
"\n",
"### 1. We need to work with a validation set and only at the end evaluate the model performance against the test set.\n",
"Remember that our test set should be withheld so that we can evaluate the model on data it hasn't seen. By iteratively checking the train and test RMSE above, we have violated that rule, so we cannot say how accurate our model is on unseen data. To fix this we need a validation set (essentially a second test set) that the model can peek at after each iteration to see how well it performs on data it hasn't trained on; then, with the final version of the model, we can make the assessment against the test set.\n",
"\n",
"### 2. If we keep tweaking one hyperparameter at a time, waiting to see the result and then repeating the process, we will be here forever. We need a systematic way to speed this up.\n"
]
},
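{
"cell_type": "markdown",
"metadata": {},
"source": [
"One systematic alternative to tweaking one hyperparameter at a time is a grid search: list every combination of candidate values, fit one model per combination, and keep the combination with the lowest validation error. So that it runs on its own, the sketch below uses single rpart regression trees (tuning maxdepth and minsplit) on simulated data; it is not xgboost. The data and the names `grid`, `val_rmse` and `best` are hypothetical.\n",
"\n",
"```r\n",
"library(rpart)\n",
"\n",
"# hypothetical simulated data standing in for the real training set\n",
"set.seed(42)\n",
"n = 600\n",
"dat = data.frame(a = runif(n), b = runif(n))\n",
"dat$y = sin(6 * dat$a) + 2 * dat$b + rnorm(n, sd = 0.2)\n",
"\n",
"# hold out a validation set; a test set would stay untouched until the very end\n",
"idx = sample.int(n, size = floor(0.8 * n))\n",
"fit_dat = dat[idx, ]\n",
"val_dat = dat[-idx, ]\n",
"\n",
"# every combination of the candidate hyperparameter values (4 x 3 = 12 models)\n",
"grid = expand.grid(maxdepth = c(1, 2, 4, 8), minsplit = c(5, 20, 50))\n",
"\n",
"# fit one tree per row of the grid and score it on the validation set\n",
"grid$val_rmse = apply(grid, 1, function(p) {\n",
"  tree = rpart(y ~ a + b, data = fit_dat,\n",
"               control = rpart.control(maxdepth = p[['maxdepth']],\n",
"                                       minsplit = p[['minsplit']], cp = 0))\n",
"  sqrt(mean((predict(tree, val_dat) - val_dat$y)^2))\n",
"})\n",
"\n",
"# the combination with the lowest validation RMSE\n",
"best = grid[which.min(grid$val_rmse), ]\n",
"best\n",
"```\n",
"\n",
"With xgboost, the loop body would instead train one model per combination of eta and max.depth against the validation watchlist and record each run's best_score. The grid idea is the same, only each fit is slower."
]
},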
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "a5bcf778-1101-4165-a1bb-5c57038c8684",
"_uuid": "b943eaf9e59a94b62d4e1e1883775f703adcf068"
},
"source": [
"## 2d. A validation set\n",
"\n",
"Validation set: another subset of our data that is withheld from the training algorithm, but evaluated at each iteration to see how the model is performing.\n",
"\n",
"Here we create it the same way as the test set, by sampling 20% of the remaining training set, and pass it into the xgboost watchlist. The algorithm will watch the RMSE of the training set (which it learns from) and of the validation set (which it can't learn parameters from; it only sees the RMSE outcome) and continue until the validation RMSE stops improving. When several datasets are in the watchlist, xgboost uses the last one for early stopping, so the validation set should be listed last.\n",
"\n",
"\n",
"### Note on tidyverse: if you switch the first cell to Karl's tidyverse version, then you need to make the following change\n",
"Karl's notebook: https://www.kaggle.com/karlcottenie/introduction-to-machine-learning-in-r-tutorial\n",
"\n",
"With the tidy version, the line\n",
"\n",
"`train_y = train_t[,'median_house_value']`\n",
"\n",
"returns a one-column tibble rather than a plain vector, because subsetting a tibble with `[` never drops it to a vector. The algorithm needs a plain vector of numbers for the y labels, so the code needs to switch to:\n",
"\n",
"`train_y = pull(train_t, median_house_value)`\n",
"\n",
"in order to un-tibble the data."
]
},
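{
"cell_type": "markdown",
"metadata": {},
"source": [
"The splitting scheme described above (hold out 20% of all rows as the test set, then 20% of what remains as the validation set) can be written as one small function. It returns three disjoint sets of row indices. The function name `three_way_split` is hypothetical, and it uses only base R. The example call uses n = 20640, the number of rows in the California housing data.\n",
"\n",
"```r\n",
"# Split row indices into train / validation / test sets:\n",
"# test_frac of all rows goes to test, then valid_frac of the rest to validation.\n",
"three_way_split = function(n, test_frac = 0.2, valid_frac = 0.2, seed = 19) {\n",
"  set.seed(seed)  # reproducible split (note: this resets the global RNG state)\n",
"  test_idx = sample.int(n, size = floor(test_frac * n))\n",
"  rest = setdiff(seq_len(n), test_idx)\n",
"  valid_idx = rest[sample.int(length(rest), size = floor(valid_frac * length(rest)))]\n",
"  list(train = setdiff(rest, valid_idx), valid = valid_idx, test = test_idx)\n",
"}\n",
"\n",
"idx = three_way_split(20640)\n",
"sapply(idx, length)  # roughly 64% / 16% / 20% of the rows\n",
"```"
]
},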
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"_cell_guid": "925d6049-28f3-4c0c-9ad3-50662652bb5a",
"_uuid": "ea90660e47d1c415116d19d3ebe572d4f85f0750"
},
"outputs": [
{
"data": {
"text/html": [
"<ol class=list-inline>\n",
"\t<li>476400</li>\n",
"\t<li>409900</li>\n",
"\t<li>235500</li>\n",
"\t<li>74700</li>\n",
"\t<li>171100</li>\n",
"\t<li>188600</li>\n",
"\t<li>94400</li>\n",
"\t<li>211100</li>\n",
"\t<li>450000</li>\n",
"\t<li>228100</li>\n",
"</ol>\n"
],
"text/latex": [
"\\begin{enumerate*}\n",
"\\item 476400\n",
"\\item 409900\n",
"\\item 235500\n",
"\\item 74700\n",
"\\item 171100\n",
"\\item 188600\n",
"\\item 94400\n",
"\\item 211100\n",
"\\item 450000\n",
"\\item 228100\n",
"\\end{enumerate*}\n"
],
"text/markdown": [
"1. 476400\n",
"2. 409900\n",
"3. 235500\n",
"4. 74700\n",
"5. 171100\n",
"6. 188600\n",
"7. 94400\n",
"8. 211100\n",
"9. 450000\n",
"10. 228100\n",
"\n",
"\n"
],
"text/plain": [
" [1] 476400 409900 235500 74700 171100 188600 94400 211100 450000 228100"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"####\n",
"# Proper use - validation set\n",
"####\n",
"\n",
"\n",
"# randomly pick 80% of the training rows; the other 20% become the validation set\n",
"sample = sample.int(n = nrow(train), size = floor(.8*nrow(train)), replace = FALSE)\n",
"\n",
"train_t = train[sample, ] # the sampled rows, used for training\n",
"valid = train[-sample, ] # everything but the sampled rows, used for validation\n",
"\n",
"train_y = train_t[,'median_house_value']\n",
"\n",
"# if the tidyverse was used, the dplyr pull() function solves the problem:\n",
"# train_y = pull(train_t, median_house_value)\n",
"train_x = train_t[, names(train_t) != 'median_house_value']\n",
"\n",
"valid_y = valid[,'median_house_value']\n",
"valid_x = valid[, names(valid) != 'median_house_value']\n",
"\n",
"train_y[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"_cell_guid": "592838aa-ffbe-4c46-a9f5-f3aad4d80311",
"_uuid": "f3df270cc7818d3cfbe4a9dd035c5c4803e12fec",
"collapsed": true
},
"outputs": [],
"source": [
"gb_train = xgb.DMatrix(data = as.matrix(train_x), label = train_y )\n",
"gb_valid = xgb.DMatrix(data = as.matrix(valid_x), label = valid_y )\n",
"# note: in Jupyter, xgb.DMatrix() gave me an error unless the data were wrapped in as.matrix(); subtle and annoying differences\n",
"\n",
"# xgboost reports the RMSE for each set in the watchlist; early stopping uses the last one (valid)\n",
"watchlist = list(train = gb_train, valid = gb_valid)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "3ef5958f-a423-4cb5-95dd-c47e10fc12c7",
"_uuid": "d4b47ef2c5af4a85acede725be9e49c95fc3cd4c"
},
"source": [
"We then run the xgboost algorithm again, this time with early stopping based on the validation RMSE, and only after training do we evaluate on the test data."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"_cell_guid": "f6b8034b-b934-4862-b95f-4a07f0f414a5",
"_uuid": "024597708a1a7005855c68f6c58576a166b1db0b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1]\ttrain-rmse:234965.703125\tvalid-rmse:232542.093750 \n",
"Multiple eval metrics are present. Will use valid_rmse for early stopping.\n",
"Will train until valid_rmse hasn't improved in 50 rounds.\n",
"\n",
"[501]\ttrain-rmse:23408.917969\tvalid-rmse:48958.464844 \n",
"[1001]\ttrain-rmse:15327.182617\tvalid-rmse:47956.554688 \n",
"[1501]\ttrain-rmse:11868.001953\tvalid-rmse:47733.187500 \n",
"[2001]\ttrain-rmse:9716.176758\tvalid-rmse:47619.566406 \n",
"[2501]\ttrain-rmse:8043.468750\tvalid-rmse:47542.125000 \n",
"Stopping. Best iteration:\n",
"[2626]\ttrain-rmse:7657.586914\tvalid-rmse:47534.148438\n",
"\n"
]
}
],
"source": [
"bst_slow = xgb.train(data= gb_train, \n",
" max.depth = 10, \n",
" eta = 0.01, \n",
" nthread = 2, \n",
" nround = 10000, \n",
" watchlist = watchlist, \n",
" objective = \"reg:linear\", \n",
" early_stopping_rounds = 50,\n",
" print_every_n = 500)"
]
},
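{
"cell_type": "markdown",
"metadata": {},
"source": [
"The log above shows the early-stopping rule at work: the best validation RMSE came at round 2626, and training stopped once 50 further rounds brought no improvement. Below is a minimal base-R sketch of that rule applied to a vector of per-round validation scores; `early_stop()` and the toy score curve are hypothetical illustrations of the logic, not xgboost's own implementation.\n",
"\n",
"```r\n",
"# hypothetical helper: where would early stopping halt, given validation RMSE per round?\n",
"early_stop = function(valid_rmse, patience = 50) {\n",
"    best_iter = 1\n",
"    for (i in seq_along(valid_rmse)) {\n",
"        if (valid_rmse[i] < valid_rmse[best_iter]) {\n",
"            best_iter = i                  # new best score: the patience window restarts\n",
"        } else if (i - best_iter >= patience) {\n",
"            return(list(best_iter = best_iter, stopped_at = i))   # patience used up\n",
"        }\n",
"    }\n",
"    list(best_iter = best_iter, stopped_at = length(valid_rmse))\n",
"}\n",
"\n",
"# toy curve: improves for 100 rounds, then slowly gets worse (overfitting)\n",
"scores = c(seq(100, 50, length.out = 100), seq(50.1, 60, length.out = 200))\n",
"early_stop(scores, patience = 50)   # best_iter 100, stopped_at 150\n",
"```\n",
"\n",
"The model that is kept corresponds to `best_iter`, not to the round where training stopped."
]
},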
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"_cell_guid": "7b48a2ad-1bd7-4edf-98c4-40717c05264e",
"_uuid": "ab102f1aa28f9858c33a22bba59422de7a924f3c"
},
"outputs": [
{
"data": {
"text/html": [
"47331.1581083117"
],
"text/latex": [
"47331.1581083117"
],
"text/markdown": [
"47331.1581083117"
],
"text/plain": [
"[1] 47331.16"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"# recall we ran the following to get the test data in the right format:\n",
"# dtest = xgb.DMatrix(data = as.matrix(test_x), label = test_y)\n",
"# here the label is left off, just to remind us that this is external data (xgboost would ignore the label during prediction anyway)\n",
"dtest = xgb.DMatrix(data = as.matrix(test_x))\n",
"\n",
"#test the model on truly external data\n",
"\n",
"y_hat_valid = predict(bst_slow, dtest)\n",
"\n",
"test_mse = mean(((y_hat_valid - test_y)^2))\n",
"test_rmse = sqrt(test_mse)\n",
"test_rmse \n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"_cell_guid": "6b1c6bf8-b702-4e28-87f2-6e25f96169d8",
"_uuid": "94509698917c57d109ceab580fab90452ca71480"
},
"outputs": [
{
"data": {
"text/html": [
"0.978078155652003"
],
"text/latex": [
"0.978078155652003"
],
"text/markdown": [
"0.978078155652003"
],
"text/plain": [
"[1] 0.9780782"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"test_rmse/rf_benchmark"
]
},
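{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same steps (square the errors, average them, take the square root) are repeated for every model in this notebook, and the result is then divided by `rf_benchmark`. A small helper keeps that consistent; the function names below are our own, not part of any package. A ratio below 1 means the model's test RMSE is lower than the random forest benchmark (0.978 means about 2.2% lower).\n",
"\n",
"```r\n",
"# hypothetical helpers, not from any package\n",
"rmse = function(pred, actual) sqrt(mean((pred - actual)^2))\n",
"\n",
"# ratio < 1: lower error than the benchmark; ratio > 1: higher error\n",
"relative_to_benchmark = function(pred, actual, benchmark_rmse) {\n",
"    rmse(pred, actual) / benchmark_rmse\n",
"}\n",
"\n",
"# toy check with made-up numbers: errors of 3 and -4 give sqrt((9 + 16) / 2)\n",
"rmse(c(103, 96), c(100, 100))   # 3.535534\n",
"```\n",
"\n",
"In the notebook, the ratio above would read `relative_to_benchmark(y_hat_valid, test_y, rf_benchmark)`."
]
},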
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "4b93c96d-2879-406d-817f-c09788278812",
"_uuid": "40940eabb4164cc958061073b29511eb73d88429"
},
"source": [
"This RMSE is higher than on the first run, but thanks to the validation set we can be confident that the score is not flattered by overfitting to the test set. A lower RMSE isn't necessarily better if it comes at the cost of overfitting; we now have more confidence in our predictions on external data.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "97e930dc-413e-4ae4-bb99-1abfb5a58cd5",
"_uuid": "bea7bb14ad8305e41337d384a98ebfba07ab71b5"
},
"source": [
"## 3a. Grid Search to find the best hyperparameter combinations"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"_cell_guid": "ee5152e3-4228-484a-97a1-dd0c24fd97bd",
"_uuid": "2f9e4be462a5469a7072c596b0e2187fc55ae193"
},
"outputs": [
{
"data": {
"text/html": [
"<dl>\n",
"\t<dt>$max_depth</dt>\n",
"\t\t<dd>7</dd>\n",
"\t<dt>$eta</dt>\n",
"\t\t<dd>0.01</dd>\n",
"\t<dt>$nthread</dt>\n",
"\t\t<dd>2</dd>\n",
"\t<dt>$objective</dt>\n",
"\t\t<dd>'reg:linear'</dd>\n",
"\t<dt>$silent</dt>\n",
"\t\t<dd>1</dd>\n",
"</dl>\n"
],
"text/latex": [
"\\begin{description}\n",
"\\item[\\$max\\_depth] 7\n",
"\\item[\\$eta] 0.01\n",
"\\item[\\$nthread] 2\n",
"\\item[\\$objective] 'reg:linear'\n",
"\\item[\\$silent] 1\n",
"\\end{description}\n"
],
"text/markdown": [
"$max_depth\n",
": 7\n",
"$eta\n",
": 0.01\n",
"$nthread\n",
": 2\n",
"$objective\n",
": 'reg:linear'\n",
"$silent\n",
": 1\n",
"\n",
"\n"
],
"text/plain": [
"$max_depth\n",
"[1] 7\n",
"\n",
"$eta\n",
"[1] 0.01\n",
"\n",
"$nthread\n",
"[1] 2\n",
"\n",
"$objective\n",
"[1] \"reg:linear\"\n",
"\n",
"$silent\n",
"[1] 1\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<strong>valid-rmse:</strong> 47117.835938"
],
"text/latex": [
"\\textbf{valid-rmse:} 47117.835938"
],
"text/markdown": [
"**valid-rmse:** 47117.835938"
],
"text/plain": [
"valid-rmse \n",
" 47117.84 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"###\n",
"# Grid search first principles \n",
"###\n",
"\n",
"max.depths = c(7, 9)\n",
"etas = c(0.01, 0.001)\n",
"\n",
"best_params = NULL\n",
"best_score = Inf  # any real validation RMSE will beat this\n",
"\n",
"for( depth in max.depths ){\n",
"    for( num in etas){\n",
"\n",
"        bst_grid = xgb.train(data = gb_train, \n",
"                             max.depth = depth, \n",
"                             eta = num, \n",
"                             nthread = 2, \n",
"                             nround = 10000, \n",
"                             watchlist = watchlist, \n",
"                             objective = \"reg:linear\", \n",
"                             early_stopping_rounds = 50, \n",
"                             verbose = 0)\n",
"\n",
"        # keep the combination with the lowest validation RMSE seen so far\n",
"        if( bst_grid$best_score < best_score){\n",
"            best_params = bst_grid$params\n",
"            best_score = bst_grid$best_score\n",
"        }\n",
"    }\n",
"}\n",
"\n",
"best_params\n",
"best_score\n",
"\n",
"\n",
"\n"
]
},
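{
"cell_type": "markdown",
"metadata": {},
"source": [
"The nested loops above can also be written with `expand.grid()`, which lays out every combination as one row of a data frame, so adding a third hyperparameter means adding one argument rather than another loop. To keep the sketch self-contained it swaps xgboost for base R's `loess()` on simulated data and tunes its `span` and `degree`; `fit_and_score()` is a hypothetical stand-in for training with `xgb.train()` and reading off `best_score`. The selection rule is the same: keep the combination with the lowest validation RMSE.\n",
"\n",
"```r\n",
"set.seed(1)\n",
"# simulated data standing in for the housing features and target\n",
"x = runif(300, 0, 10)\n",
"y = sin(x) + rnorm(300, sd = 0.3)\n",
"is_valid = seq_along(x) %in% sample.int(300, 60)   # 20% validation rows\n",
"\n",
"# hypothetical stand-in for xgb.train(): fit on training rows, score on validation rows\n",
"fit_and_score = function(span, degree) {\n",
"    fit = loess(y ~ x, data = data.frame(x = x[!is_valid], y = y[!is_valid]),\n",
"                span = span, degree = degree,\n",
"                control = loess.control(surface = 'direct'))   # allows extrapolation\n",
"    pred = predict(fit, newdata = data.frame(x = x[is_valid]))\n",
"    sqrt(mean((pred - y[is_valid])^2))                           # validation RMSE\n",
"}\n",
"\n",
"grid = expand.grid(span = c(0.2, 0.5, 0.8), degree = c(1, 2))\n",
"grid$valid_rmse = mapply(fit_and_score, grid$span, grid$degree)\n",
"\n",
"grid[which.min(grid$valid_rmse), ]   # best combination by validation RMSE\n",
"```\n",
"\n",
"As in the notebook, the winning combination would then be refit and evaluated once on the test set."
]
},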
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"_cell_guid": "2b0208b1-72d6-4f69-906a-598a1a776d09",
"_uuid": "3364136560f1e1717ceb933e11e110ae4ce70f89"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1]\ttrain-rmse:234989.328125\tvalid-rmse:232543.312500 \n",
"Multiple eval metrics are present. Will use valid_rmse for early stopping.\n",
"Will train until valid_rmse hasn't improved in 50 rounds.\n",
"\n",
"[501]\ttrain-rmse:38440.367188\tvalid-rmse:50656.929688 \n",
"[1001]\ttrain-rmse:31137.027344\tvalid-rmse:48560.308594 \n",
"[1501]\ttrain-rmse:26716.404297\tvalid-rmse:47650.949219 \n",
"[2001]\ttrain-rmse:23859.535156\tvalid-rmse:47432.820312 \n",
"[2501]\ttrain-rmse:21716.740234\tvalid-rmse:47233.796875 \n",
"[3001]\ttrain-rmse:19926.523438\tvalid-rmse:47135.058594 \n",
"Stopping. Best iteration:\n",
"[3109]\ttrain-rmse:19554.800781\tvalid-rmse:47117.835938\n",
"\n"
]
},
{
"data": {
"text/html": [
"46146.4750834592"
],
"text/latex": [
"46146.4750834592"
],
"text/markdown": [
"46146.4750834592"
],
"text/plain": [
"[1] 46146.48"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# retrain with the best combination from the grid search: max_depth of 7, eta of 0.01\n",
"bst_tuned = xgb.train( data = gb_train, \n",
" max.depth = 7, \n",
" eta = 0.01, \n",
" nthread = 2, \n",
" nround = 10000, \n",
" watchlist = watchlist, \n",
" objective = \"reg:linear\", \n",
" early_stopping_rounds = 50,\n",
" print_every_n = 500)\n",
"\n",
"y_hat_xgb_grid = predict(bst_tuned, dtest)\n",
"\n",
"test_mse = mean(((y_hat_xgb_grid - test_y)^2))\n",
"test_rmse = sqrt(test_mse)\n",
"test_rmse # 46146 in the run shown here\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"_cell_guid": "b5d75011-c262-47be-8b5e-969b6cb2b763",
"_uuid": "c1d656c1f8d2f43b026961474181fdc95686a2df"
},
"outputs": [
{
"data": {
"text/html": [
"0.953597187209854"
],
"text/latex": [
"0.953597187209854"
],
"text/markdown": [
"0.953597187209854"
],
"text/plain": [
"[1] 0.9535972"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"test_rmse/rf_benchmark"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "da75f66b-1889-4fc4-9281-cd572cfd6682",
"_uuid": "abb75959388ce6cb272189fe1a812f27204bae59"
},
"source": [
"By tuning the hyperparameters we have improved a little more on the random forest: the test RMSE is now about 4.6% below the benchmark, compared with about 2.2% for the untuned run. It is, however, only a small improvement over the untuned xgboost model, but differences of this size do matter in some circumstances!\n",
"\n",
"\n",
"## 3b. Efficiently tweak the hyperparameters using a grid search/cross-validation\n",
"\n",
"The caret package (short for Classification And REgression Training) simplifies the grid search we just implemented by hand. We pass it a grid of hyperparameter combinations, and it fits every combination and evaluates each one with cross-validation (so no separate validation set is needed).\n",
"\n",
"[caret info](http://topepo.github.io/caret/index.html)\n",
"\n",
"Like the tidyverse, it works really well... but you have to learn all its code tricks!\n"
]
},
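{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what caret's cross-validation does (the `trainControl()` call below asks for 5 folds), here is a minimal base-R sketch of 5-fold cross-validation for a single model: every row is assigned to exactly one fold, each fold takes one turn as the held-out set, and the five RMSEs are averaged. The simulated data and the `lm()` model are placeholders; caret repeats this loop for every row of the tuning grid.\n",
"\n",
"```r\n",
"set.seed(2)\n",
"# simulated placeholder data\n",
"n = 500\n",
"dat = data.frame(x1 = rnorm(n), x2 = rnorm(n))\n",
"dat$y = 3 * dat$x1 - 2 * dat$x2 + rnorm(n)\n",
"\n",
"k = 5\n",
"fold = sample(rep(1:k, length.out = n))   # each row gets exactly one fold label\n",
"\n",
"cv_rmse = sapply(1:k, function(f) {\n",
"    train_rows = dat[fold != f, ]          # fit on the other k - 1 folds\n",
"    test_rows = dat[fold == f, ]           # score on the held-out fold\n",
"    fit = lm(y ~ x1 + x2, data = train_rows)\n",
"    sqrt(mean((predict(fit, test_rows) - test_rows$y)^2))\n",
"})\n",
"\n",
"cv_rmse         # one RMSE per fold\n",
"mean(cv_rmse)   # the cross-validated estimate for this one setting\n",
"```\n",
"\n",
"Every row is used for validation exactly once, which is why cross-validation can replace a single fixed validation set."
]
},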
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"_cell_guid": "2980bb24-e53e-430d-92dc-69d4605b5bee",
"_uuid": "02d8b0fa7b5f09013205d785c805ab07bb513c4d"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading required package: lattice\n",
"\n",
"Attaching package: ‘caret’\n",
"\n",
"The following object is masked from ‘package:purrr’:\n",
"\n",
" lift\n",
"\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
"<thead><tr><th scope=col>model</th><th scope=col>parameter</th><th scope=col>label</th><th scope=col>forReg</th><th scope=col>forClass</th><th scope=col>probModel</th></tr></thead>\n",
"<tbody>\n",
"\t<tr><td>xgbLinear </td><td>nrounds </td><td># Boosting Iterations</td><td>TRUE </td><td>TRUE </td><td>TRUE </td></tr>\n",
"\t<tr><td>xgbLinear </td><td>lambda </td><td>L2 Regularization </td><td>TRUE </td><td>TRUE </td><td>TRUE </td></tr>\n",
"\t<tr><td>xgbLinear </td><td>alpha </td><td>L1 Regularization </td><td>TRUE </td><td>TRUE </td><td>TRUE </td></tr>\n",
"\t<tr><td>xgbLinear </td><td>eta </td><td>Learning Rate </td><td>TRUE </td><td>TRUE </td><td>TRUE </td></tr>\n",
"</tbody>\n",
"</table>\n"
],
"text/latex": [
"\\begin{tabular}{r|llllll}\n",
" model & parameter & label & forReg & forClass & probModel\\\\\n",
"\\hline\n",
"\t xgbLinear & nrounds & \\# Boosting Iterations & TRUE & TRUE & TRUE \\\\\n",
"\t xgbLinear & lambda & L2 Regularization & TRUE & TRUE & TRUE \\\\\n",
"\t xgbLinear & alpha & L1 Regularization & TRUE & TRUE & TRUE \\\\\n",
"\t xgbLinear & eta & Learning Rate & TRUE & TRUE & TRUE \\\\\n",
"\\end{tabular}\n"
],
"text/markdown": [
"\n",
"model | parameter | label | forReg | forClass | probModel | \n",
"|---|---|---|---|---|---|\n",
"| xgbLinear | nrounds | # Boosting Iterations | TRUE | TRUE | TRUE | \n",
"| xgbLinear | lambda | L2 Regularization | TRUE | TRUE | TRUE | \n",
"| xgbLinear | alpha | L1 Regularization | TRUE | TRUE | TRUE | \n",
"| xgbLinear | eta | Learning Rate | TRUE | TRUE | TRUE | \n",
"\n",
"\n"
],
"text/plain": [
" model parameter label forReg forClass probModel\n",
"1 xgbLinear nrounds # Boosting Iterations TRUE TRUE TRUE \n",
"2 xgbLinear lambda L2 Regularization TRUE TRUE TRUE \n",
"3 xgbLinear alpha L1 Regularization TRUE TRUE TRUE \n",
"4 xgbLinear eta Learning Rate TRUE TRUE TRUE "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table>\n",
"<thead><tr><th scope=col>nrounds</th><th scope=col>eta</th><th scope=col>lambda</th><th scope=col>alpha</th></tr></thead>\n",
"<tbody>\n",
"\t<tr><td>1000 </td><td>1e-02</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>2000 </td><td>1e-02</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>3000 </td><td>1e-02</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>4000 </td><td>1e-02</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>1000 </td><td>1e-03</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>2000 </td><td>1e-03</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>3000 </td><td>1e-03</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>4000 </td><td>1e-03</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>1000 </td><td>1e-04</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>2000 </td><td>1e-04</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>3000 </td><td>1e-04</td><td>1 </td><td>0 </td></tr>\n",
"\t<tr><td>4000 </td><td>1e-04</td><td>1 </td><td>0 </td></tr>\n",
"</tbody>\n",
"</table>\n"
],
"text/latex": [
"\\begin{tabular}{r|llll}\n",
" nrounds & eta & lambda & alpha\\\\\n",
"\\hline\n",
"\t 1000 & 1e-02 & 1 & 0 \\\\\n",
"\t 2000 & 1e-02 & 1 & 0 \\\\\n",
"\t 3000 & 1e-02 & 1 & 0 \\\\\n",
"\t 4000 & 1e-02 & 1 & 0 \\\\\n",
"\t 1000 & 1e-03 & 1 & 0 \\\\\n",
"\t 2000 & 1e-03 & 1 & 0 \\\\\n",
"\t 3000 & 1e-03 & 1 & 0 \\\\\n",
"\t 4000 & 1e-03 & 1 & 0 \\\\\n",
"\t 1000 & 1e-04 & 1 & 0 \\\\\n",
"\t 2000 & 1e-04 & 1 & 0 \\\\\n",
"\t 3000 & 1e-04 & 1 & 0 \\\\\n",
"\t 4000 & 1e-04 & 1 & 0 \\\\\n",
"\\end{tabular}\n"
],
"text/markdown": [
"\n",
"nrounds | eta | lambda | alpha | \n",
"|---|---|---|---|\n",
"| 1000 | 1e-02 | 1 | 0 | \n",
"| 2000 | 1e-02 | 1 | 0 | \n",
"| 3000 | 1e-02 | 1 | 0 | \n",
"| 4000 | 1e-02 | 1 | 0 | \n",
"| 1000 | 1e-03 | 1 | 0 | \n",
"| 2000 | 1e-03 | 1 | 0 | \n",
"| 3000 | 1e-03 | 1 | 0 | \n",
"| 4000 | 1e-03 | 1 | 0 | \n",
"| 1000 | 1e-04 | 1 | 0 | \n",
"| 2000 | 1e-04 | 1 | 0 | \n",
"| 3000 | 1e-04 | 1 | 0 | \n",
"| 4000 | 1e-04 | 1 | 0 | \n",
"\n",
"\n"
],
"text/plain": [
" nrounds eta lambda alpha\n",
"1 1000 1e-02 1 0 \n",
"2 2000 1e-02 1 0 \n",
"3 3000 1e-02 1 0 \n",
"4 4000 1e-02 1 0 \n",
"5 1000 1e-03 1 0 \n",
"6 2000 1e-03 1 0 \n",
"7 3000 1e-03 1 0 \n",
"8 4000 1e-03 1 0 \n",
"9 1000 1e-04 1 0 \n",
"10 2000 1e-04 1 0 \n",
"11 3000 1e-04 1 0 \n",
"12 4000 1e-04 1 0 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"library(caret) \n",
"\n",
"# look up the model we are running to see its tuning parameters\n",
"modelLookup(\"xgbLinear\")\n",
" \n",
"# set up all the pairwise combinations\n",
"\n",
"xgb_grid_1 = expand.grid(nrounds = c(1000,2000,3000,4000) ,\n",
" eta = c(0.01, 0.001, 0.0001),\n",
" lambda = 1,\n",
" alpha = 0)\n",
"xgb_grid_1\n",
"\n",
"\n",
"# here we go one better than a validation set: 5-fold cross-validation\n",
"# uses every training row for validation once, so we get more information out of our data\n",
"xgb_trcontrol_1 = trainControl(method = \"cv\",\n",
" number = 5,\n",
" verboseIter = TRUE,\n",
" returnData = FALSE,\n",
" returnResamp = \"all\", \n",
" allowParallel = TRUE)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "302fe6c1-0bd3-485a-93a8-9992f18108cc",
"_uuid": "9e441ea3ee974911ca4378eda40fa4032b8604cd"
},
"source": [
"Train the model for each parameter combination in the grid, using CV to evaluate it on multiple folds. Make sure your laptop is plugged in, or else RIP battery if you've got a tiny old MacBook like me.\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"_cell_guid": "5fee2a03-45f7-41d4-b0ce-0e1c5e38c9c3",
"_uuid": "06ca7c361b11ec55d3f42a033099c06aea553193"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+ Fold1: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold1: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold1: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold1: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold1: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold1: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold1: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold1: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold1: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold1: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold1: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold1: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold1: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold1: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold2: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold2: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold2: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold2: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold2: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold2: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold2: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold2: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold2: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold2: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold2: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold2: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold2: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold3: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold3: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold3: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold3: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold3: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold3: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold3: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold3: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold3: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold3: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold3: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold3: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold3: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold4: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold4: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold4: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold4: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold4: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold4: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold4: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold4: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold4: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold4: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold4: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold4: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold4: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold5: nrounds=1000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold5: nrounds=2000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold5: nrounds=3000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"- Fold5: nrounds=4000, eta=1e-02, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold5: nrounds=1000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold5: nrounds=2000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold5: nrounds=3000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"- Fold5: nrounds=4000, eta=1e-03, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold5: nrounds=1000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold5: nrounds=2000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold5: nrounds=3000, eta=1e-04, lambda=1, alpha=0 \n",
"+ Fold5: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"- Fold5: nrounds=4000, eta=1e-04, lambda=1, alpha=0 \n",
"Aggregating results\n",
"Selecting tuning parameters\n",
"Fitting nrounds = 1000, lambda = 1, alpha = 0, eta = 1e-04 on full training set\n"
]
},
{
"data": {
"text/html": [
"<ol class=list-inline>\n",
"\t<li>'method'</li>\n",
"\t<li>'modelInfo'</li>\n",
"\t<li>'modelType'</li>\n",
"\t<li>'results'</li>\n",
"\t<li>'pred'</li>\n",
"\t<li>'bestTune'</li>\n",
"\t<li>'call'</li>\n",
"\t<li>'dots'</li>\n",
"\t<li>'metric'</li>\n",
"\t<li>'control'</li>\n",
"\t<li>'finalModel'</li>\n",
"\t<li>'preProcess'</li>\n",
"\t<li>'trainingData'</li>\n",
"\t<li>'resample'</li>\n",
"\t<li>'resampledCM'</li>\n",
"\t<li>'perfNames'</li>\n",
"\t<li>'maximize'</li>\n",
"\t<li>'yLimits'</li>\n",
"\t<li>'times'</li>\n",
"\t<li>'levels'</li>\n",
"</ol>\n"
],
"text/latex": [
"\\begin{enumerate*}\n",
"\\item 'method'\n",
"\\item 'modelInfo'\n",
"\\item 'modelType'\n",
"\\item 'results'\n",
"\\item 'pred'\n",
"\\item 'bestTune'\n",
"\\item 'call'\n",
"\\item 'dots'\n",
"\\item 'metric'\n",
"\\item 'control'\n",
"\\item 'finalModel'\n",
"\\item 'preProcess'\n",
"\\item 'trainingData'\n",
"\\item 'resample'\n",
"\\item 'resampledCM'\n",
"\\item 'perfNames'\n",
"\\item 'maximize'\n",
"\\item 'yLimits'\n",
"\\item 'times'\n",
"\\item 'levels'\n",
"\\end{enumerate*}\n"
],
"text/markdown": [
"1. 'method'\n",
"2. 'modelInfo'\n",
"3. 'modelType'\n",
"4. 'results'\n",
"5. 'pred'\n",
"6. 'bestTune'\n",
"7. 'call'\n",
"8. 'dots'\n",
"9. 'metric'\n",
"10. 'control'\n",
"11. 'finalModel'\n",
"12. 'preProcess'\n",
"13. 'trainingData'\n",
"14. 'resample'\n",
"15. 'resampledCM'\n",
"16. 'perfNames'\n",
"17. 'maximize'\n",
"18. 'yLimits'\n",
"19. 'times'\n",
"20. 'levels'\n",
"\n",
"\n"
],
"text/plain": [
" [1] \"method\" \"modelInfo\" \"modelType\" \"results\" \"pred\" \n",
" [6] \"bestTune\" \"call\" \"dots\" \"metric\" \"control\" \n",
"[11] \"finalModel\" \"preProcess\" \"trainingData\" \"resample\" \"resampledCM\" \n",
"[16] \"perfNames\" \"maximize\" \"yLimits\" \"times\" \"levels\" "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table>\n",
"<thead><tr><th scope=col>nrounds</th><th scope=col>lambda</th><th scope=col>alpha</th><th scope=col>eta</th></tr></thead>\n",
"<tbody>\n",
"\t<tr><td>1000 </td><td>1 </td><td>0 </td><td>1e-04</td></tr>\n",
"</tbody>\n",
"</table>\n"
],
"text/latex": [
"\\begin{tabular}{r|llll}\n",
" nrounds & lambda & alpha & eta\\\\\n",
"\\hline\n",
"\t 1000 & 1 & 0 & 1e-04\\\\\n",
"\\end{tabular}\n"
],
"text/markdown": [
"\n",
"nrounds | lambda | alpha | eta | \n",
"|---|---|---|---|\n",
"| 1000 | 1 | 0 | 1e-04 | \n",
"\n",
"\n"
],
"text/plain": [
" nrounds lambda alpha eta \n",
"1 1000 1 0 1e-04"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"'xgbLinear'"
],
"text/latex": [
"'xgbLinear'"
],
"text/markdown": [
"'xgbLinear'"
],
"text/plain": [
"[1] \"xgbLinear\""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
" Length Class Mode \n",
"handle 1 xgb.Booster.handle externalptr\n",
"raw 2073968 -none- raw \n",
"niter 1 -none- numeric \n",
"call 6 -none- call \n",
"params 5 -none- list \n",
"callbacks 1 -none- list \n",
"feature_names 13 -none- character \n",
"xNames 13 -none- character \n",
"problemType 1 -none- character \n",
"tuneValue 4 data.frame list \n",
"obsLevels 1 -none- logical \n",
"param 1 -none- list "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"47744.1678007726"
],
"text/latex": [
"47744.1678007726"
],
"text/markdown": [
"47744.1678007726"
],
"text/plain": [
"[1] 47744.17"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"######\n",
"#below a grid-search, cross-validation xgboost model in caret\n",
"######\n",
"\n",
"\n",
"xgb_train_1 = train(x = as.matrix(train_x),\n",
" y = train_y,\n",
" trControl = xgb_trcontrol_1,\n",
" tuneGrid = xgb_grid_1,\n",
" method = \"xgbLinear\",\n",
" max.depth = 5)\n",
"\n",
"names(xgb_train_1)\n",
"xgb_train_1$bestTune\n",
"xgb_train_1$method\n",
"summary(xgb_train_1)\n",
"\n",
"\n",
"# alternatively, you can 'narrow in' on the best parameters: repeat the above with a range of options around\n",
"# the best values found and see whether finer-grained tweaks provide further improvements.\n",
"\n",
"xgb_cv_yhat = predict(xgb_train_1 , as.matrix(test_x))\n",
"\n",
"\n",
"test_mse = mean(((xgb_cv_yhat - test_y)^2))\n",
"test_rmse = sqrt(test_mse)\n",
"test_rmse # 47744... higher than the 'by hand' grid search!\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "17ce8a75-b45a-4e87-ab98-feefbbbf916d",
"_uuid": "8dd13b59197b0f35bbdb5d192c3a72631fc1ab9a"
},
"source": [
"Cam's hypothesis on caret's performance: we are not using early stopping here, so training doesn't stop at the best iteration. Re-running this with a validation set instead of cross-validation would let us combine an efficient grid search with early stopping and end up with the best hyperparameters. We also didn't tune tree depth here, and it may in fact be important for the performance gains we saw. I shall leave this as a follow-up exercise for the curious.\n",
"\n",
"\n",
"## 4. Ensemble the models together\n",
"\n",
"This is a good strategy when accuracy is more important than knowing which predictors matter most.\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"_cell_guid": "f27abbe7-5c81-40b8-9741-911765eef59f",
"_uuid": "c9188c7d9363ad99727fb445169e50848d4eddb9"
},
"outputs": [
{
"data": {
"text/html": [
"4128"
],
"text/latex": [
"4128"
],
"text/markdown": [
"4128"
],
"text/plain": [
"[1] 4128"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"4128"
],
"text/latex": [
"4128"
],
"text/markdown": [
"4128"
],
"text/plain": [
"[1] 4128"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"TRUE"
],
"text/latex": [
"TRUE"
],
"text/markdown": [
"TRUE"
],
"text/plain": [
"[1] TRUE"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"45407.7365122905"
],
"text/latex": [
"45407.7365122905"
],
"text/markdown": [
"45407.7365122905"
],
"text/plain": [
"[1] 45407.74"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#y_pred_rf #random forest\n",
"#y_hat_valid #xgBoost with validation\n",
"#y_hat_xgb_grid #xgBoost grid search\n",
"#xgb_cv_yhat #xgBoost caret cross validation\n",
"\n",
"length(y_hat_xgb_grid)\n",
"\n",
"\n",
"# equal-weight average of the four models' test-set predictions\n",
"blend_pred = (y_hat_valid * .25) + (y_pred_rf * .25) + (xgb_cv_yhat * .25) + (y_hat_xgb_grid * .25)\n",
"length(blend_pred)\n",
"\n",
"length(blend_pred) == length(y_hat_xgb_grid)\n",
"\n",
"blend_test_mse = mean(((blend_pred - test_y)^2))\n",
"blend_test_rmse = sqrt(blend_test_mse)\n",
"blend_test_rmse \n"
]
},
{
"cell_type": "markdown",
"metadata": {
"_cell_guid": "450b65ee-9fb4-4a9a-999a-1687b9c396d7",
"_uuid": "f3490fd5993e7c55058a3743c525a70bd2315c9a"
},
"source": [
"Just by averaging four (very similar) models we have lowered the test RMSE to about 45408, roughly 1.6% below the best of the four individual models (46146). This does come at a cost, though: we can no longer make clear inferences about which predictors matter most! The strategy is usually more effective when you ensemble a more diverse set of models.\n",
"\n",
"### Next step: you can grid search the weights of the ensemble to try to lower the RMSE further (tune the weights on validation predictions, not on the test set)!\n"
]
}
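,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that weight search, assuming you have each model's predictions on a validation set rather than the test set (tuning the weights on the test set would repeat the mistake discussed earlier). The three prediction vectors are simulated stand-ins, and only weights in steps of 0.1 that sum to 1 are tried.\n",
"\n",
"```r\n",
"set.seed(3)\n",
"# simulated stand-ins for three models' predictions on validation data\n",
"truth = rnorm(1000, mean = 200000, sd = 100000)\n",
"preds = cbind(m1 = truth + rnorm(1000, sd = 45000),\n",
"              m2 = truth + rnorm(1000, sd = 50000),\n",
"              m3 = truth + rnorm(1000, sd = 60000))\n",
"\n",
"rmse = function(pred, actual) sqrt(mean((pred - actual)^2))\n",
"\n",
"# all weight triples in steps of 0.1 (integer tenths avoid floating-point sums)\n",
"w = expand.grid(w1 = 0:10, w2 = 0:10, w3 = 0:10)\n",
"w = w[w$w1 + w$w2 + w$w3 == 10, ] / 10\n",
"\n",
"# blended prediction = preds %*% weights; score each weight triple\n",
"w$rmse = apply(w, 1, function(wt) rmse(preds %*% wt, truth))\n",
"w[which.min(w$rmse), ]   # best weights on the validation data\n",
"```\n",
"\n",
"In the notebook, the columns of `preds` would be each model's predictions on held-out validation rows; the winning weights would then be applied once to the test-set predictions."
]
}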
],
"metadata": {
"kernelspec": {
"display_name": "R",
"language": "R",
"name": "ir"
},
"language_info": {
"codemirror_mode": "r",
"file_extension": ".r",
"mimetype": "text/x-r-source",
"name": "R",
"pygments_lexer": "r",
"version": "3.3.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}