class: center, middle, inverse, title-slide

# Machine learning
## Decision trees
### July 12th, 2021

---

## What is Machine Learning?

--

The short version:

- Machine learning (ML) is a subset of statistical learning that focuses on prediction

--

The longer version:

- ML focuses on constructing data-driven algorithms that *learn* the mapping between predictor variables and response variable(s).

- We do not assume a parametric form for the mapping *a priori*, even if technically one can write one down *a posteriori* (e.g., by translating a tree model to an indicator-variable mathematical expression)

  - e.g., linear regression is NOT considered an ML algorithm since we can write down the linear equation ahead of time
  
  - e.g., random forests are considered an ML algorithm since we do NOT know in advance what the trees will look like

---

## Which algorithm is best?

__That's not the right question to ask.__

(And the answer is *not* deep learning. Because if the underlying relationship between your predictors and your response is truly linear, *you do not need to apply deep learning*! Just do linear regression. Really. It's OK.)

--

The right question to ask is: __why should I try different algorithms?__

--

The answer is that without superhuman powers, you cannot visualize the distribution of the predictor variables in their native space.

- Of course, you can visualize these data *in projection*, for instance when we perform EDA

- And the performance of different algorithms will depend on how the predictor data are distributed...

---

### Data geometry

<img src="http://www.stat.cmu.edu/~pfreeman/data_geometry.png" width="70%" style="display: block; margin: auto;" />

- Two predictor variables with binary response variable: x's and o's

- __LHS__: Linear boundaries that form rectangles will perform well in predicting the response

- __RHS__: Circular boundaries will perform better

---

## Decision trees

Decision trees partition training data into __homogeneous nodes / subgroups__ with similar response values

--

The subgroups are found __recursively using binary partitions__

- i.e. asking a series of yes-no questions about the predictor variables

We stop splitting the tree once a __stopping criterion__ has been reached (e.g. maximum depth allowed)

--

For each subgroup / node, predictions are made with:

- Regression tree: __the average of the response values__ in the node

- Classification tree: __the most popular class__ in the node

--

The most popular approach is Leo Breiman's __C__lassification __A__nd __R__egression __T__ree (CART) algorithm

---

## Decision tree structure

<img src="https://bradleyboehmke.github.io/HOML/images/decision-tree-terminology.png" width="100%" style="display: block; margin: auto;" />

---

## Decision tree structure

We make a prediction for an observation by __following its path along the tree__

<img src="https://bradleyboehmke.github.io/HOML/images/exemplar-decision-tree.png" width="100%" style="display: block; margin: auto;" />

--

- Decision trees are __very easy to explain__ to non-statisticians.

- Easy to visualize and thus easy to interpret __without assuming a parametric form__
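---

## Node predictions, concretely

To make the node predictions above concrete, here is a minimal sketch with made-up response values (these numbers are illustrative only and are not part of the MLB example used later):

```r
# Regression tree node: the prediction is the average response in the node
node_responses <- c(0.310, 0.325, 0.342, 0.298)
mean(node_responses)

# Classification tree node: the prediction is the most popular class in the node
node_classes <- c("HR", "Out", "Out", "Out", "Single")
names(which.max(table(node_classes)))
```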
---

### Recursive splits: each _split / rule_ depends on the previous split / rule _above_ it

__Objective at each split__: find the __best__ variable to partition the data into one of two regions, `\(R_1\)` & `\(R_2\)`, to __minimize the error__ between the actual response, `\(y_i\)`, and the predicted constant for each region, `\(c_1\)` or `\(c_2\)`

--

- For regression we minimize the sum of squared errors (SSE):

`$$SSE = \sum_{i \in R_{1}}\left(y_{i}-c_{1}\right)^{2}+\sum_{i \in R_{2}}\left(y_{i}-c_{2}\right)^{2}$$`

--

- For classification trees we minimize a measure of the node's _impurity_, the __Gini index__

  - where `\(p_k\)` is the proportion of observations in the node belonging to class `\(k\)` out of `\(K\)` total classes
  
  - we want to minimize `\(Gini\)`: small values indicate a node has primarily one class (_is more pure_)

`$$Gini = 1 - \sum_{k=1}^K p_k^2$$`

--

Splits are chosen greedily and yield __locally optimal__ results, so we are NOT guaranteed to train a model that is globally optimal

--

_How do we control the complexity of the tree?_

---

## Tune the __maximum tree depth__ or __minimum node size__

<img src="https://bradleyboehmke.github.io/HOML/07-decision-trees_files/figure-html/dt-early-stopping-1.png" width="60%" style="display: block; margin: auto;" />

---

## Prune the tree by tuning __cost complexity__

Can grow a very large, complicated tree, and then __prune__ it back to an optimal __subtree__ using a __cost complexity__ parameter `\(\alpha\)` (like `\(\lambda\)` for elastic net)

- `\(\alpha\)` penalizes the objective as a function of the number of __terminal nodes__

- e.g., we want to minimize `\(SSE + \alpha \cdot (\# \text{ of terminal nodes})\)`

<img src="https://bradleyboehmke.github.io/HOML/07-decision-trees_files/figure-html/pruned-tree-1.png" width="80%" style="display: block; margin: auto;" />

---

## Example data: MLB 2021 batting statistics

Downloaded MLB 2021 batting statistics leaderboard from [Fangraphs](https://www.fangraphs.com/leaders.aspx?pos=all&stats=bat&lg=all&qual=y&type=8&season=2019&month=0&season1=2021&ind=0)

```r
library(tidyverse)
mlb_data <- read_csv("http://www.stat.cmu.edu/cmsac/sure/2021/materials/data/fg_batting_2021.csv") %>%
  janitor::clean_names() %>%
  mutate_at(vars(bb_percent:k_percent), parse_number)
head(mlb_data)
```

```
## # A tibble: 6 x 23
##   name     team      g    pa    hr     r   rbi    sb bb_percent k_percent   iso babip   avg   obp
##   <chr>    <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>      <dbl>     <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Vladimi… TOR      82   354    27    66    69     2       14.4      17.2 0.336 0.346 0.336 0.438
## 2 Fernand… SDP      68   288    27    66    58    18       12.5      28.1 0.395 0.333 0.302 0.385
## 3 Carlos … HOU      79   347    16    61    52     0       13.5      17   0.231 0.324 0.298 0.398
## 4 Marcus … TOR      82   372    21    63    54    10        8.9      23.9 0.256 0.329 0.286 0.349
## 5 Ronald … ATL      78   342    23    67    51    16       13.2      24.3 0.313 0.306 0.278 0.386
## 6 Shohei … LAA      82   322    31    60    67    12       11.2      28   0.418 0.29  0.277 0.363
## # … with 9 more variables: slg <dbl>, w_oba <dbl>, xw_oba <dbl>, w_rc <dbl>, bs_r <dbl>,
## #   off <dbl>, def <dbl>, war <dbl>, playerid <dbl>
```
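---

## Sketch: the split search by hand

Before handing things to `rpart`, here is a rough sketch (not part of the original example) of the greedy SSE split search described earlier, applied to a single predictor in `mlb_data`; the helper `split_sse()` is a name introduced here for illustration:

```r
# For one predictor, try each candidate cutpoint and compute the SSE of
# predicting each side of the split with its own mean response
split_sse <- function(x, y, cutpoint) {
  left <- y[x < cutpoint]
  right <- y[x >= cutpoint]
  sum((left - mean(left))^2) + sum((right - mean(right))^2)
}

cutpoints <- sort(unique(mlb_data$iso))[-1]  # drop the min so both sides are non-empty
iso_sse <- sapply(cutpoints, split_sse, x = mlb_data$iso, y = mlb_data$w_oba)
cutpoints[which.min(iso_sse)]  # best single split on iso (should land near the root split below)
```

CART repeats this search over every predictor and candidate cutpoint, keeps the best split, and then recurses within each resulting region.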
---

## Regression tree example with the [`rpart` package](https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf)

Revisit the modeling of `w_oba` from the KNN slides

```r
library(rpart)
init_mlb_tree <- rpart(formula = w_oba ~ bb_percent + k_percent + iso,
                       data = mlb_data, method = "anova")
init_mlb_tree
```

```
## n= 135 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 135 0.203847700 0.3383259  
##    2) iso< 0.2 82 0.069028260 0.3188171  
##      4) iso< 0.1315 27 0.021396070 0.2981852  
##        8) bb_percent< 10.15 19 0.013706740 0.2894737 *
##        9) bb_percent>=10.15 8 0.002822875 0.3188750 *
##      5) iso>=0.1315 55 0.030496840 0.3289455  
##       10) k_percent>=15.15 47 0.020832550 0.3243404  
##         20) bb_percent< 11.45 37 0.012078700 0.3186216  
##           40) iso< 0.1585 13 0.002205692 0.3071538 *
##           41) iso>=0.1585 24 0.007237333 0.3248333 *
##         21) bb_percent>=11.45 10 0.003066500 0.3455000 *
##       11) k_percent< 15.15 8 0.002812000 0.3560000 *
##    3) iso>=0.2 53 0.055325250 0.3685094  
##      6) iso< 0.283 46 0.029749830 0.3607826  
##       12) bb_percent< 6.5 8 0.000972000 0.3335000 *
##       13) bb_percent>=6.5 38 0.021569470 0.3665263  
##         26) iso< 0.251 29 0.014496690 0.3601034  
##           52) k_percent>=22.65 14 0.005877500 0.3505000 *
##           53) k_percent< 22.65 15 0.006122933 0.3690667 *
##         27) iso>=0.251 9 0.002021556 0.3872222 *
##      7) iso>=0.283 7 0.004781429 0.4192857 *
```

---

## Display the tree with [`rpart.plot`](http://www.milbo.org/rpart-plot/)

.pull-left[

```r
library(rpart.plot)
rpart.plot(init_mlb_tree)
```

<img src="20-Trees-rf_files/figure-html/plot-tree-1.png" width="504" style="display: block; margin: auto;" />

]

.pull-right[

- `rpart()` runs 10-fold CV to tune `\(\alpha\)` for pruning

- Selects the number of terminal nodes via the 1 SE rule

```r
plotcp(init_mlb_tree)
```

<img src="20-Trees-rf_files/figure-html/plot-complexity-1.png" width="504" style="display: block; margin: auto;" />

]

---

## What about the full tree? (check out `rpart.control`)

.pull-left[

```r
full_mlb_tree <- rpart(formula = w_oba ~ bb_percent + k_percent + iso,
                       data = mlb_data, method = "anova",
                       control = list(cp = 0, xval = 10))
rpart.plot(full_mlb_tree)
```

<img src="20-Trees-rf_files/figure-html/plot-full-tree-1.png" width="504" style="display: block; margin: auto;" />

]

.pull-right[

```r
plotcp(full_mlb_tree)
```

<img src="20-Trees-rf_files/figure-html/plot-full-complexity-1.png" width="504" style="display: block; margin: auto;" />

]

---

## Train with `caret`

```r
library(caret)
caret_mlb_tree <- train(w_oba ~ bb_percent + k_percent + iso + avg + obp + slg + war,
                        data = mlb_data, method = "rpart",
                        trControl = trainControl(method = "cv", number = 10),
                        tuneLength = 20)
ggplot(caret_mlb_tree) + theme_bw()
```

<img src="20-Trees-rf_files/figure-html/caret-tree-1.png" width="504" style="display: block; margin: auto;" />

---

## Display the final model

```r
rpart.plot(caret_mlb_tree$finalModel)
```

<img src="20-Trees-rf_files/figure-html/unnamed-chunk-6-1.png" width="504" style="display: block; margin: auto;" />

---

## Summarizing variables in tree-based models

.pull-left[

__Variable importance__ - based on reduction in SSE (_notice anything odd?_)

```r
library(vip)
vip(caret_mlb_tree, geom = "point") + theme_bw()
```

<img src="20-Trees-rf_files/figure-html/var-imp-1.png" width="504" style="display: block; margin: auto;" />

]

.pull-right[

- Summarize a single variable's relationship with the response using a __partial dependence plot__

```r
library(pdp)
partial(caret_mlb_tree, pred.var = "obp") %>% autoplot() + theme_bw()
```

<img src="20-Trees-rf_files/figure-html/pdp-1.png" width="504" style="display: block; margin: auto;" />

]
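---

## Pruning the full tree by hand

As a follow-up sketch (not in the original slides): one way to apply the 1 SE rule directly to `full_mlb_tree` using the cross-validation results stored in its `cptable`, then prune with `rpart::prune()`. The names `cp_table`, `best_cp`, and `pruned_mlb_tree` are introduced here for illustration:

```r
# Cross-validation results for each candidate cp, from the simplest tree
# (top row) to the most complex (bottom row)
cp_table <- as.data.frame(full_mlb_tree$cptable)

# 1 SE rule: simplest tree whose CV error is within one SE of the minimum
one_se_cutoff <- min(cp_table$xerror) + cp_table$xstd[which.min(cp_table$xerror)]
best_cp <- cp_table$CP[which(cp_table$xerror <= one_se_cutoff)[1]]

pruned_mlb_tree <- prune(full_mlb_tree, cp = best_cp)
rpart.plot(pruned_mlb_tree)
```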