class: center, middle, inverse, title-slide

# Machine learning
## Decision trees
### July 18th, 2022

---

## What is Machine Learning?

--

The short version:

- Machine learning (ML) is a subset of statistical learning that focuses on prediction

--

The longer version:

- ML focuses on constructing data-driven algorithms that *learn* the mapping between predictor variables and response variable(s).

- We do not assume a parametric form for the mapping *a priori*, even if technically one can write one down *a posteriori* (e.g., by translating a tree model to an indicator-variable mathematical expression)

  - e.g., linear regression is NOT considered an ML algorithm since we can write down the linear equation ahead of time

  - e.g., random forests ARE considered an ML algorithm since we do not know what the trees will look like in advance

---

## Which algorithm is best?

__That's not the right question to ask.__

(And the answer is *not* deep learning. Because if the underlying relationship between your predictors and your response is truly linear, *you do not need to apply deep learning*! Just do linear regression. Really. It's OK.)

--

The right question to ask is: __why should I try different algorithms?__

--

The answer is that, without superhuman powers, you cannot visualize the distribution of the predictor variables in their native space.

- Of course, you can visualize these data *in projection*, for instance when we perform EDA

- And the performance of different algorithms will depend on how the predictor data are distributed...

---

### Data geometry

<img src="http://www.stat.cmu.edu/~pfreeman/data_geometry.png" width="70%" style="display: block; margin: auto;" />

- Two predictor variables with a binary response variable: x's and o's

- __LHS__: Linear boundaries that form rectangles will perform well in predicting the response

- __RHS__: Circular boundaries will perform better

---

## Decision trees

Decision trees partition the training data into __homogeneous nodes / subgroups__ with similar response values

--

The subgroups are found __recursively using binary partitions__

- i.e., asking a series of yes-no questions about the predictor variables

We stop splitting the tree once a __stopping criterion__ has been reached (e.g., maximum depth allowed)

--

For each subgroup / node, predictions are made with:

- Regression tree: __the average of the response values__ in the node

- Classification tree: __the most popular class__ in the node

--

The most popular approach is Leo Breiman's __C__lassification __A__nd __R__egression __T__ree (CART) algorithm

---

## Decision tree structure

<img src="https://bradleyboehmke.github.io/HOML/images/decision-tree-terminology.png" width="100%" style="display: block; margin: auto;" />

---

## Decision tree structure

We make a prediction for an observation by __following its path along the tree__

<img src="https://bradleyboehmke.github.io/HOML/images/exemplar-decision-tree.png" width="100%" style="display: block; margin: auto;" />

--

- Decision trees are __very easy to explain__ to non-statisticians.

- Easy to visualize and thus easy to interpret __without assuming a parametric form__
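---

### Toy example: node predictions

To make the two prediction rules concrete, here is a minimal toy sketch (made-up data and a made-up split at `x < 3.5`, not from the example we use later): every observation that lands in a node simply receives that node's summary of the response

```r
library(tidyverse)
# hypothetical toy data: one predictor, a numeric response, a categorical response
toy_data <- tibble(x = c(1, 2, 3, 4, 5, 6),
                   y_num = c(1.2, 0.9, 1.1, 3.0, 2.8, 3.3),
                   y_cat = c("o", "o", "x", "x", "x", "x"))
toy_data %>%
  mutate(node = if_else(x < 3.5, "left node", "right node")) %>%   # a single binary split on x
  group_by(node) %>%
  summarize(reg_pred = mean(y_num),                         # regression tree: average response in node
            class_pred = names(which.max(table(y_cat))))    # classification tree: most popular class in node
```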
---

### Recursive splits: each _split / rule_ depends on the previous split / rule _above_ it

__Objective at each split__: find the __best__ variable to partition the data into one of two regions, `\(R_1\)` & `\(R_2\)`, to __minimize the error__ between the actual response, `\(y_i\)`, and the region's predicted constant, `\(c_1\)` or `\(c_2\)`

--

- For regression we minimize the sum of squared errors (SSE):

`$$SSE = \sum_{i \in R_1}\left(y_i - c_1\right)^2 + \sum_{i \in R_2}\left(y_i - c_2\right)^2$$`

--

- For classification trees we minimize a measure of the node's _impurity_, the __Gini index__

  - where `\(p_k\)` is the proportion of observations in the node belonging to class `\(k\)` out of `\(K\)` total classes

  - want to minimize `\(Gini\)`: small values indicate a node has primarily one class (_is more pure_)

`$$Gini = 1 - \sum_{k=1}^K p_k^2$$`

--

Splits yield __locally optimal__ results, so we are NOT guaranteed to train a model that is globally optimal

--

_How do we control the complexity of the tree?_

---

## Tune the __maximum tree depth__ or __minimum node size__

<img src="https://bradleyboehmke.github.io/HOML/07-decision-trees_files/figure-html/dt-early-stopping-1.png" width="60%" style="display: block; margin: auto;" />

---

## Prune the tree by tuning __cost complexity__

Can grow a very large, complicated tree, and then __prune__ back to an optimal __subtree__ using a __cost complexity__ parameter `\(\alpha\)` (like `\(\lambda\)` for elastic net)

- `\(\alpha\)` penalizes the objective as a function of the number of __terminal nodes__

- e.g., we want to minimize `\(SSE + \alpha \cdot (\# \text{ of terminal nodes})\)`

<img src="https://bradleyboehmke.github.io/HOML/07-decision-trees_files/figure-html/pruned-tree-1.png" width="80%" style="display: block; margin: auto;" />

---

## Example data: MLB 2022 batting statistics

Downloaded MLB 2022 batting statistics leaderboard from [Fangraphs](https://www.fangraphs.com/leaders.aspx?pos=all&stats=bat&lg=all&qual=y&type=8&season=2022&month=0&season1=2022&ind=0)

```r
library(tidyverse)
mlb_data <- read_csv("http://www.stat.cmu.edu/cmsac/sure/2022/materials/data/sports/fg_batting_2022.csv") %>%
  janitor::clean_names() %>%
  mutate_at(vars(bb_percent:k_percent), parse_number)
head(mlb_data)
```

```
## # A tibble: 6 x 23
##   name   team      g    pa    hr     r   rbi    sb bb_percent k_percent   iso babip   avg
##   <chr>  <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>      <dbl>     <dbl> <dbl> <dbl> <dbl>
## 1 Rafae… BOS      85   374    22    62    55     2        6.7      17.6 0.28  0.353 0.327
## 2 Aaron… NYY      88   385    33    72    69     8       11.4      26   0.337 0.296 0.281
## 3 Nolan… STL      88   370    18    41    59     1        8.9      13   0.233 0.295 0.293
## 4 Manny… SDP      82   349    15    56    51     7       10.6      18.6 0.213 0.346 0.306
## 5 Paul … STL      90   391    20    64    70     5       12        21.2 0.26  0.388 0.33 
## 6 Jose … CLE      87   375    19    54    75    13       10.7       9.9 0.288 0.275 0.288
## # … with 10 more variables: obp <dbl>, slg <dbl>, w_oba <dbl>, xw_oba <dbl>, w_rc <dbl>,
## #   bs_r <dbl>, off <dbl>, def <dbl>, war <dbl>, playerid <dbl>
```
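---

## Aside: what does a single split search look like?

Before handing the work to `rpart`, here is a minimal sketch of the split search described earlier (the helper `find_best_split()` is our own illustration, NOT how `rpart` is implemented internally): for one predictor, evaluate the SSE of every candidate cutpoint and keep the best one

```r
find_best_split <- function(x, y) {
  x_values <- sort(unique(x))
  # candidate cutpoints: midpoints between consecutive observed values
  cutpoints <- (x_values[-1] + x_values[-length(x_values)]) / 2
  sse <- sapply(cutpoints, function(cutpoint) {
    left <- y[x < cutpoint]
    right <- y[x >= cutpoint]
    # the SSE objective from the earlier slide
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  list(cutpoint = cutpoints[which.min(sse)], sse = min(sse))
}
find_best_split(mlb_data$iso, mlb_data$w_oba)
```

A real CART implementation repeats this search over every predictor at every node, then recurses on the two resulting regions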
---

## Regression tree example with the [`rpart` package](https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf)

```r
library(rpart)
init_mlb_tree <- rpart(formula = w_oba ~ bb_percent + k_percent + iso,
                       data = mlb_data, method = "anova")
init_mlb_tree
```

```
## n= 157 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 157 0.215948200 0.3291338  
##    2) iso< 0.2055 123 0.113126200 0.3175691  
##      4) iso< 0.1035 16 0.016633000 0.2837500 *
##      5) iso>=0.1035 107 0.075457050 0.3226262  
##       10) bb_percent< 8.75 65 0.039689380 0.3146154  
##         20) k_percent>=27.15 9 0.001585556 0.2902222 *
##         21) k_percent< 27.15 56 0.031887930 0.3185357  
##           42) iso< 0.152 27 0.010937850 0.3089259 *
##           43) iso>=0.152 29 0.016135240 0.3274828  
##             86) k_percent>=21.85 17 0.008568235 0.3194706 *
##             87) k_percent< 21.85 12 0.004929667 0.3388333 *
##       11) bb_percent>=8.75 42 0.025140980 0.3350238  
##         22) k_percent>=23.45 11 0.002378909 0.3129091 *
##         23) k_percent< 23.45 31 0.015473480 0.3428710  
##           46) iso< 0.159 15 0.006778000 0.3320000 *
##           47) iso>=0.159 16 0.005260937 0.3530625 *
##    3) iso>=0.2055 34 0.026860970 0.3709706  
##      6) iso< 0.2595 23 0.009236609 0.3608696 *
##      7) iso>=0.2595 11 0.010370910 0.3920909 *
```

---

## Display the tree with [`rpart.plot`](http://www.milbo.org/rpart-plot/)

.pull-left[

```r
library(rpart.plot)
rpart.plot(init_mlb_tree)
```

<img src="19-Trees_files/figure-html/plot-tree-1.png" width="504" style="display: block; margin: auto;" />

]

.pull-right[

- `rpart()` runs 10-fold CV to tune `\(\alpha\)` for pruning

- Selects the # of terminal nodes via the 1 SE rule

```r
plotcp(init_mlb_tree)
```

<img src="19-Trees_files/figure-html/plot-complexity-1.png" width="504" style="display: block; margin: auto;" />

]

---

## What about the full tree? (check out `rpart.control`)

.pull-left[

```r
full_mlb_tree <- rpart(formula = w_oba ~ bb_percent + k_percent + iso,
                       data = mlb_data, method = "anova",
                       control = list(cp = 0, xval = 10))
rpart.plot(full_mlb_tree)
```

<img src="19-Trees_files/figure-html/plot-full-tree-1.png" width="504" style="display: block; margin: auto;" />

]

.pull-right[

```r
plotcp(full_mlb_tree)
```

<img src="19-Trees_files/figure-html/plot-full-complexity-1.png" width="504" style="display: block; margin: auto;" />

]

---

## Train with `caret`

```r
library(caret)
caret_mlb_tree <- train(w_oba ~ bb_percent + k_percent + iso + avg + obp + slg + war,
                        data = mlb_data, method = "rpart",
                        trControl = trainControl(method = "cv", number = 10),
                        tuneLength = 20)
ggplot(caret_mlb_tree) + theme_bw()
```

<img src="19-Trees_files/figure-html/caret-tree-1.png" width="504" style="display: block; margin: auto;" />

---

## Display the final model

```r
rpart.plot(caret_mlb_tree$finalModel)
```

<img src="19-Trees_files/figure-html/unnamed-chunk-6-1.png" width="504" style="display: block; margin: auto;" />

---

## Summarizing variables in tree-based models

.pull-left[

__Variable importance__ - based on reduction in SSE (_notice anything odd?_)

```r
library(vip)
vip(caret_mlb_tree, geom = "point") + theme_bw()
```

<img src="19-Trees_files/figure-html/var-imp-1.png" width="504" style="display: block; margin: auto;" />

]

.pull-right[

- Summarize a single variable's relationship with the response using a __partial dependence plot__

```r
library(pdp)
partial(caret_mlb_tree, pred.var = "obp") %>%
  autoplot() + theme_bw()
```

<img src="19-Trees_files/figure-html/pdp-1.png" width="504" style="display: block; margin: auto;" />

]

---

## Classification: predicting MLB HRs

Used the [`baseballr`](http://billpetti.github.io/baseballr/) package to scrape all batted balls from the 2022 season for the EDA project:

```r
library(tidyverse)
batted_ball_data <- read_csv("http://www.stat.cmu.edu/cmsac/sure/2022/materials/data/sports/eda_projects/mlb_batted_balls_2022.csv") %>%
  mutate(is_hr = as.numeric(events == "home_run")) %>%
  filter(!is.na(launch_angle), !is.na(launch_speed),
         !is.na(is_hr))
table(batted_ball_data$is_hr)
```

```
## 
##    0    1 
## 6702  333
```
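---

## Checking the Gini index by hand

As a quick link back to the impurity measure from earlier, here is a minimal sketch (the 100 mph exit velocity threshold is made up for illustration, and the `gini()` helper is our own) of how pure each side of one candidate split would be

```r
# Gini index of a node: 1 minus the sum of squared class proportions
gini <- function(y) {
  p <- prop.table(table(y))
  1 - sum(p^2)
}
batted_ball_data %>%
  mutate(node = if_else(launch_speed < 100, "below 100 mph", "100+ mph")) %>%  # one hypothetical split
  group_by(node) %>%
  summarize(n = n(),
            hr_rate = mean(is_hr),      # proportion of HRs in the node
            node_gini = gini(is_hr))    # smaller = more pure
```

The softly-hit node should be nearly all non-HRs (very pure), while the harder-hit node should be far less pure, which is exactly the kind of contrast a classification tree exploits when it searches for splits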
---

## Predict HRs with launch angle and exit velocity?

.pull-left[

```r
batted_ball_data %>%
  ggplot(aes(x = launch_speed, y = launch_angle,
             color = as.factor(is_hr))) +
  geom_point(alpha = 0.5) +
  ggthemes::scale_color_colorblind(labels = c("No", "Yes")) +
  labs(x = "Exit velocity", y = "Launch angle",
       color = "HR?") +
  theme_bw() +
  theme(legend.position = "bottom")
```

- HRs are relatively rare and confined to one area of this plot

]

.pull-right[

<img src="19-Trees_files/figure-html/unnamed-chunk-7-1.png" width="504" />

]

---

## Train with `caret`

```r
library(caret)
caret_hr_tree <- train(as.factor(is_hr) ~ launch_speed + launch_angle,
                       data = batted_ball_data, method = "rpart",
                       trControl = trainControl(method = "cv", number = 10),
                       tuneLength = 20)
ggplot(caret_hr_tree) + theme_bw()
```

<img src="19-Trees_files/figure-html/hr-tree-1.png" width="504" style="display: block; margin: auto;" />

---

## Display the final model

```r
rpart.plot(caret_hr_tree$finalModel)
```

<img src="19-Trees_files/figure-html/unnamed-chunk-8-1.png" width="504" style="display: block; margin: auto;" />
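---

## How well does the final tree classify HRs?

As a final sanity check, here is a minimal sketch of comparing the tuned tree's predicted classes to the observed labels (these are in-sample predictions, purely for illustration; honest error estimates come from the 10-fold CV that `caret` already performed)

```r
# predicted class (0 = no HR, 1 = HR) for every batted ball in the data
hr_preds <- predict(caret_hr_tree, newdata = batted_ball_data)
# confusion table of predictions vs. observed outcomes
# (caret::confusionMatrix() would give a fuller summary)
table("Predicted HR" = hr_preds, "Observed HR" = batted_ball_data$is_hr)
```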