class: center, middle, inverse, title-slide .title[ # Tune better models ] .author[ ### INFO 5940
Cornell University ] --- class: inverse, middle # Decision trees --- ## Decision Trees To predict the outcome of a new data point: - Uses rules learned from splits - Each split maximizes information gain --- <img src="https://media.giphy.com/media/gj4ZruUQUnpug/source.gif" width="50%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-5-1.png" width="80%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-6-1.png" width="80%" style="display: block; margin: auto;" /> --- ## Quiz How do we assess predictions here? -- Root Mean Squared Error (RMSE) --- <img src="index_files/figure-html/rt-test-resid-1.png" width="80%" style="display: block; margin: auto;" /> --- class: middle, center <img src="https://raw.githubusercontent.com/EmilHvitfeldt/blog/master/static/blog/2019-08-09-authorship-classification-with-tidymodels-and-textrecipes_files/figure-html/unnamed-chunk-18-1.png" width="70%" style="display: block; margin: auto;" /> https://www.emilhvitfeldt.com/post/2019-08-09-authorship-classification-with-tidymodels-and-textrecipes/ --- class: middle, center <img src="https://www.kaylinpavlik.com/content/images/2019/12/dt-1.png" width="50%" style="display: block; margin: auto;" /> https://www.kaylinpavlik.com/classifying-songs-genres/ --- class: middle, center <img src="https://a3.typepad.com/6a0105360ba1c6970c01b7c95c61fb970b-pi" width="40%" style="display: block; margin: auto;" /> .footnote[[tweetbotornot2](https://github.com/mkearney/tweetbotornot2)] --- <img src="images/guess_the_animal.gif" width="80%" style="display: block; margin: auto;" /> --- ## What makes a good guesser? -- High information gain per question (can it fly?) -- Clear features (feathers vs. is it "small"?) 
-- Order matters --- <img src="images/aus-standard-animals.png" width="90%" style="display: block; margin: auto;" /> .footnote[[Australian Computing Academy](https://aca.edu.au/resources/decision-trees-classifying-animals/)] --- <img src="images/aus-standard-tree.png" width="90%" style="display: block; margin: auto;" /> .footnote[[Australian Computing Academy](https://aca.edu.au/resources/decision-trees-classifying-animals/)] --- <img src="images/annotated-tree/annotated-tree.001.png" width="90%" style="display: block; margin: auto;" /> --- <img src="images/annotated-tree/annotated-tree.002.png" width="90%" style="display: block; margin: auto;" /> --- <img src="images/annotated-tree/annotated-tree.003.png" width="90%" style="display: block; margin: auto;" /> --- <img src="images/annotated-tree/annotated-tree.004.png" width="90%" style="display: block; margin: auto;" /> --- <img src="images/annotated-tree/annotated-tree.005.png" width="90%" style="display: block; margin: auto;" /> --- ## To specify a model with `parsnip` 1\. Pick a .display[model] + .display[engine] 2\. Set the .display[mode] (if needed) --- ## To specify a decision tree model with `parsnip` ```r tree_mod <- decision_tree(engine = "rpart") %>% set_mode("classification") ``` --- class: middle, center <img src="index_files/figure-html/bechdel-tree-01-1.png" width="60%" style="display: block; margin: auto;" /> ``` ## nn test Fai Pas cover ## 2 Fail [.81 .19] when imdb_rating >= 8.1 7% ## 6 Fail [.57 .43] when imdb_rating is 6.4 to 8.1 63% ## 7 Pass [.46 .54] when imdb_rating < 6.4 30% ``` --- .pull-left[ <img src="index_files/figure-html/unnamed-chunk-22-1.png" width="80%" style="display: block; margin: auto;" /> ] .pull-right[ <img src="index_files/figure-html/unnamed-chunk-23-1.png" width="80%" style="display: block; margin: auto;" /> ] --- class: inverse ## ⏱ Your turn 1 Here is our very-vanilla parsnip model specification for a decision tree (also in your qmd)... 
```r
tree_mod <- 
  decision_tree(engine = "rpart") %>% 
  set_mode("classification")
```

And a workflow:

```r
tree_wf <- 
  workflow() %>% 
  add_formula(test ~ .) %>% 
  add_model(tree_mod)
```

For decision trees, no recipe really required 🎉

---
class: inverse

## ⏱ Your turn 1

Fill in the blanks to return the accuracy and ROC AUC for this model using 10-fold cross-validation.
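---

## Cross-validation by hand

A minimal base-R sketch of what 10-fold cross-validation computes, offered as an illustration and not as a description of how `fit_resamples()` works internally. It assumes the built-in `iris` data as a stand-in for the Bechdel data and calls **rpart** directly; `fold_id` and `fold_acc` are hypothetical names.

```r
library(rpart)  # ships with R as a recommended package

set.seed(100)
dat <- iris  # stand-in data (assumption), not the Bechdel data
k <- 10

# randomly assign every row to one of k folds
fold_id <- sample(rep(seq_len(k), length.out = nrow(dat)))

# hold each fold out once: fit on the other folds, score on the held-out fold
fold_acc <- vapply(seq_len(k), function(i) {
  analysis   <- dat[fold_id != i, ]
  assessment <- dat[fold_id == i, ]
  fit  <- rpart(Species ~ ., data = analysis, method = "class")
  pred <- predict(fit, newdata = assessment, type = "class")
  mean(pred == assessment$Species)
}, numeric(1))

mean(fold_acc)          # cross-validated accuracy
sd(fold_acc) / sqrt(k)  # standard error across folds
```

The mean and standard error play the same role as the `mean` and `std_err` columns that `collect_metrics()` reports.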
---

```r
set.seed(100)
tree_wf %>% 
  fit_resamples(resamples = bechdel_folds) %>% 
  collect_metrics()
## # A tibble: 2 × 6
##   .metric  .estimator  mean     n std_err .config
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>
## 1 accuracy binary     0.602     9  0.0177 Preprocessor1_Mod…
## 2 roc_auc  binary     0.629     9  0.0159 Preprocessor1_Mod…
```

---

## `args()`

Print the arguments for a **parsnip** model specification.

```r
args(decision_tree)
## function (mode = "unknown", engine = "rpart", cost_complexity = NULL, 
##     tree_depth = NULL, min_n = NULL) 
## NULL
```

---

## `decision_tree()`

Specifies a decision tree model

```r
decision_tree(engine = "rpart", tree_depth = 30, min_n = 20, cost_complexity = .01)
```

--

*either* mode works!

---

## `decision_tree()`

Specifies a decision tree model

```r
decision_tree(
  engine = "rpart",       # default computational engine
  tree_depth = 30,        # max tree depth
  min_n = 20,             # smallest node allowed
  cost_complexity = .01   # typically 0 < cp < 0.1
)
```

---

## `set_args()`

Change the arguments for a **parsnip** model specification.

```r
_mod %>% set_args(tree_depth = 3)
```

---

```r
decision_tree(engine = "rpart") %>%
  set_mode("classification") %>%
* set_args(tree_depth = 3)
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   tree_depth = 3
## 
## Computational engine: rpart
```

---

```r
*decision_tree(engine = "rpart", tree_depth = 3) %>%
  set_mode("classification")
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   tree_depth = 3
## 
## Computational engine: rpart
```

---

## `tree_depth`

Cap the maximum tree depth.

A method to stop the tree early. Used to prevent **overfitting.**

```r
tree_mod %>% set_args(tree_depth = 30)
```

---

<img src="index_files/figure-html/unnamed-chunk-36-1.png" width="80%" style="display: block; margin: auto;" />

---

<img src="index_files/figure-html/unnamed-chunk-37-1.png" width="80%" style="display: block; margin: auto;" />

---

## `min_n`

Set minimum `n` to split at any node.
Another early stopping method. Used to prevent overfitting. ```r tree_mod %>% set_args(min_n = 20) ``` --- ## Quiz What value of `min_n` would lead to the *most overfit* tree? -- `min_n` = 1 --- class: middle, center, frame ## Recap: early stopping | `parsnip` arg | `rpart` arg | default | overfit? | |---------------|-------------|:-------:|:--------:| | `tree_depth` | `maxdepth` | 30 |⬆️| | `min_n` | `minsplit` | 20 |⬇️| --- ## `cost_complexity` Adds a cost or penalty to error rates of more complex trees. A way to prune a tree. Used to prevent overfitting. ```r tree_mod %>% set_args(cost_complexity = .01) ``` -- Closer to zero ➡️ larger trees. Higher penalty ➡️ smaller trees. --- <img src="index_files/figure-html/unnamed-chunk-40-1.png" width="80%" style="display: block; margin: auto;" /> --- name: bonsai background-image: url(images/kari-shea-AVqh83jStMA-unsplash.jpg) background-position: left background-size: contain class: middle --- template: bonsai .pull-right[ ## Consider the bonsai 1. Small pot 1. Strong shears ] --- template: bonsai .pull-right[ ## Consider the bonsai 1. ~~Small pot~~ .display[Early stopping] 1. ~~Strong shears~~ .display[Pruning] ] --- class: middle, center, frame ## Recap: early stopping & pruning | `parsnip` arg | `rpart` arg | default | overfit? 
| |---------------|-------------|:-------:|:--------:| | `tree_depth` | `maxdepth` | 30 |⬆️| | `min_n` | `minsplit` | 20 |⬇️| | `cost_complexity` | `cp` | .01 |⬇️| --- class: middle, center <table> <thead> <tr> <th style="text-align:left;"> engine </th> <th style="text-align:left;"> parsnip </th> <th style="text-align:left;"> original </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> rpart </td> <td style="text-align:left;"> tree_depth </td> <td style="text-align:left;"> maxdepth </td> </tr> <tr> <td style="text-align:left;"> rpart </td> <td style="text-align:left;"> min_n </td> <td style="text-align:left;"> minsplit </td> </tr> <tr> <td style="text-align:left;"> rpart </td> <td style="text-align:left;"> cost_complexity </td> <td style="text-align:left;"> cp </td> </tr> </tbody> </table> <https://rdrr.io/cran/rpart/man/rpart.control.html> --- <img src="index_files/figure-html/unnamed-chunk-42-1.png" width="80%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-43-1.png" width="80%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-44-1.png" width="80%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-45-1.png" width="80%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-46-1.png" width="80%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-47-1.png" width="80%" style="display: block; margin: auto;" /> --- <img src="index_files/figure-html/unnamed-chunk-48-1.png" width="80%" style="display: block; margin: auto;" /> --- ## Axiom There is an inverse relationship between model .display[accuracy] and model .display[interpretability]. --- class: inverse, middle ## Random forests --- ## `rand_forest()` Specifies a random forest model ```r rand_forest(mtry = 4, trees = 500, min_n = 1) ``` -- *either* mode works! 
---

## `rand_forest()`

Specifies a random forest model

```r
rand_forest(
  engine = "ranger",  # default computational engine
  mtry = 4,           # predictors seen at each node
  trees = 500,        # trees per forest
  min_n = 1           # smallest node allowed
)
```

---
class: inverse

## ⏱ Your turn 2

Create a new parsnip model called `rf_mod`, which will learn an ensemble of classification trees from our training data using the **ranger** engine. Update your `tree_wf` with this new model.

Fit your workflow with 10-fold cross-validation and compare the ROC AUC of the random forest to your single decision tree model. Which one predicts the held-out folds better?

*Hint: you'll need https://www.tidymodels.org/find/parsnip/*
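---

## An ensemble of trees, by hand

A minimal sketch of the ensemble idea: fit many trees on bootstrap samples and let them vote. This is plain bagging with **rpart**, not the **ranger** algorithm, because it does not sample `mtry` predictors at each split. The built-in `iris` data stands in for the Bechdel data, and `n_trees`, `votes`, and `ensemble_pred` are hypothetical names.

```r
library(rpart)  # ships with R as a recommended package

set.seed(100)
dat <- iris  # stand-in data (assumption), not the Bechdel data
test_rows <- sample(nrow(dat), 30)
train <- dat[-test_rows, ]
test  <- dat[test_rows, ]
n_trees <- 25

# one column of class predictions per tree, each tree fit on a bootstrap sample
votes <- sapply(seq_len(n_trees), function(b) {
  boot <- train[sample(nrow(train), replace = TRUE), ]
  fit  <- rpart(Species ~ ., data = boot, method = "class")
  as.character(predict(fit, newdata = test, type = "class"))
})

# majority vote across trees for every test row
ensemble_pred <- apply(votes, 1, function(v) names(which.max(table(v))))

mean(ensemble_pred == test$Species)  # accuracy of the voted prediction
```

A random forest adds one more source of randomness on top of this: each split only considers a random subset of the predictors.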
--- ```r rf_mod <- rand_forest(engine = "ranger") %>% set_mode("classification") rf_wf <- tree_wf %>% update_model(rf_mod) set.seed(100) rf_wf %>% fit_resamples(resamples = bechdel_folds) %>% collect_metrics() ## # A tibble: 2 × 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 accuracy binary 0.615 10 0.0142 Preprocessor1_Mod… ## 2 roc_auc binary 0.636 10 0.0172 Preprocessor1_Mod… ``` --- ## `mtry` The number of predictors that will be randomly sampled at each split when creating the tree models. ```r rand_forest(mtry = 5) ``` **ranger** default = `floor(sqrt(num_predictors))` --- .pull-left[ ### Single decision tree ```r tree_mod <- decision_tree(engine = "rpart") %>% set_mode("classification") tree_wf <- workflow() %>% add_formula(test ~ .) %>% add_model(tree_mod) set.seed(100) tree_res <- tree_wf %>% fit_resamples( resamples = bechdel_folds, control = control_resamples(save_pred = TRUE) ) tree_res %>% collect_metrics() ## # A tibble: 2 × 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 accuracy binary 0.602 9 0.0177 Preprocessor1_Mod… ## 2 roc_auc binary 0.629 9 0.0159 Preprocessor1_Mod… ``` ] -- .pull-right[ ### A random forest of trees ```r rf_mod <- rand_forest(engine = "ranger") %>% set_mode("classification") rf_wf <- tree_wf %>% update_model(rf_mod) set.seed(100) rf_res <- rf_wf %>% fit_resamples( resamples = bechdel_folds, control = control_resamples(save_pred = TRUE) ) rf_res %>% collect_metrics() ## # A tibble: 2 × 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 accuracy binary 0.615 10 0.0142 Preprocessor1_Mod… ## 2 roc_auc binary 0.636 10 0.0172 Preprocessor1_Mod… ``` ] --- .pull-left[ ### Single decision tree <img src="index_files/figure-html/unnamed-chunk-56-1.png" width="80%" style="display: block; margin: auto;" /> ] -- .pull-right[ ### A random forest of trees <img src="index_files/figure-html/unnamed-chunk-57-1.png" 
width="80%" style="display: block; margin: auto;" /> ] --- class: inverse ## ⏱ Your turn 3 Challenge: Fit 3 more random forest models, each using 3, 5, and 8 variables at each split. Update your `rf_wf` with each new model. Which value maximizes the area under the ROC curve?
--- ```r rf3_mod <- rf_mod %>% * set_args(mtry = 3) rf5_mod <- rf_mod %>% * set_args(mtry = 5) rf8_mod <- rf_mod %>% * set_args(mtry = 8) ``` --- ```r rf3_wf <- rf_wf %>% update_model(rf3_mod) set.seed(100) rf3_wf %>% fit_resamples(resamples = bechdel_folds) %>% collect_metrics() ## # A tibble: 2 × 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 accuracy binary 0.611 10 0.0127 Preprocessor1_Mod… ## 2 roc_auc binary 0.634 10 0.0160 Preprocessor1_Mod… ``` --- ```r rf5_wf <- rf_wf %>% update_model(rf5_mod) set.seed(100) rf5_wf %>% fit_resamples(resamples = bechdel_folds) %>% collect_metrics() ## # A tibble: 2 × 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 accuracy binary 0.603 10 0.0180 Preprocessor1_Mod… ## 2 roc_auc binary 0.628 10 0.0173 Preprocessor1_Mod… ``` --- ```r rf8_wf <- rf_wf %>% update_model(rf8_mod) set.seed(100) rf8_wf %>% fit_resamples(resamples = bechdel_folds) %>% collect_metrics() ## # A tibble: 2 × 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 accuracy binary 0.606 10 0.0146 Preprocessor1_Mod… ## 2 roc_auc binary 0.627 10 0.0189 Preprocessor1_Mod… ``` --- class: inverse, middle # 🎷 Fitting and tuning models with `tune` --- class: middle, center, frame ## `tune` Functions for fitting and tuning models <https://tune.tidymodels.org> <iframe src="https://tune.tidymodels.org" width="100%" height="400px" data-external="1"></iframe> --- ## `tune()` A placeholder for hyper-parameters to be "tuned" ```r nearest_neighbor(neighbors = tune()) ``` --- ## `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. .pull-left[ ```r tune_grid( object, resamples, ..., grid = 10, metrics = NULL, control = control_grid() ) ``` ] --- ## `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. 
.pull-left[ ```r tune_grid( * object, resamples, ..., grid = 10, metrics = NULL, control = control_grid() ) ``` ] -- .pull-right[ One of: + A parsnip `model` object + A `workflow` ] --- ## `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. .pull-left[ ```r tune_grid( * object, * preprocessor, resamples, ..., grid = 10, metrics = NULL, control = control_grid() ) ``` ] .pull-right[ A `model` + `recipe` ] --- ## `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. .pull-left[ ```r tune_grid( object, resamples, ..., * grid = 10, metrics = NULL, control = control_grid() ) ``` ] .pull-right[ One of: + A positive integer. + A data frame of tuning combinations. ] --- ## `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. .pull-left[ ```r tune_grid( object, resamples, ..., * grid = 10, metrics = NULL, control = control_grid() ) ``` ] .pull-right[ Number of candidate parameter sets to be created automatically; `10` is the default. ] --- class: inverse ## ⏱ Your Turn 4 Here's our random forest model plus workflow to work with. ```r rf_mod <- rand_forest(engine = "ranger") %>% set_mode("classification") rf_wf <- workflow() %>% add_formula(test ~ .) %>% add_model(rf_mod) ``` --- class: inverse ## ⏱ Your Turn 4 Here is the output from `fit_resamples()`... ```r set.seed(100) # Important! rf_results <- rf_wf %>% fit_resamples(resamples = bechdel_folds) rf_results %>% collect_metrics() ## # A tibble: 2 × 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 accuracy binary 0.615 10 0.0142 Preprocessor1_Mod… ## 2 roc_auc binary 0.636 10 0.0172 Preprocessor1_Mod… ``` --- class: inverse ## ⏱ Your Turn 4 Edit the random forest model to tune the `mtry` and `min_n` hyperparameters. Update your workflow to use the tuned model. 
Then use `tune_grid()` to find the best combination of hyper-parameters to maximize `roc_auc`; let tune set up the grid for you. How does it compare to the average ROC AUC across folds from `fit_resamples()`?
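---

## Grid search by hand

A minimal sketch of the idea behind a grid search: score every candidate combination with cross-validation, then rank the combinations. It is an illustration under stated assumptions, not how `tune_grid()` is implemented. It tunes the **rpart** arguments `cp` and `minsplit` on the built-in `iris` data instead of `mtry` and `min_n` on the Bechdel data; the grid values and the helper `cv_accuracy()` are hypothetical.

```r
library(rpart)  # ships with R as a recommended package

set.seed(100)
dat <- iris  # stand-in data (assumption), not the Bechdel data
k <- 5
fold_id <- sample(rep(seq_len(k), length.out = nrow(dat)))

# candidate hyper-parameter combinations (hypothetical values)
grid <- expand.grid(cp = c(0.001, 0.01, 0.1), minsplit = c(5, 20))

# mean accuracy across the k folds for one combination
cv_accuracy <- function(cp, minsplit) {
  acc <- vapply(seq_len(k), function(i) {
    fit <- rpart(
      Species ~ ., data = dat[fold_id != i, ], method = "class",
      control = rpart.control(cp = cp, minsplit = minsplit)
    )
    pred <- predict(fit, newdata = dat[fold_id == i, ], type = "class")
    mean(pred == dat$Species[fold_id == i])
  }, numeric(1))
  mean(acc)
}

grid$mean_accuracy <- mapply(cv_accuracy, grid$cp, grid$minsplit)
grid[order(-grid$mean_accuracy), ]  # best combinations first
```

Every combination is scored on the same folds, which is why the seed and the fold assignment are fixed before the search starts.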
--- ```r rf_tuner <- rand_forest( engine = "ranger", mtry = tune(), min_n = tune() ) %>% set_mode("classification") rf_wf <- rf_wf %>% update_model(rf_tuner) set.seed(100) # Important! rf_results <- rf_wf %>% tune_grid(resamples = bechdel_folds) ``` --- ```r rf_results %>% collect_metrics() ## # A tibble: 20 × 8 ## mtry min_n .metric .estim…¹ mean n std_err .config ## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 4 28 accuracy binary 0.621 10 0.0129 Prepro… ## 2 4 28 roc_auc binary 0.639 10 0.0165 Prepro… ## 3 3 36 accuracy binary 0.621 10 0.0138 Prepro… ## 4 3 36 roc_auc binary 0.641 10 0.0178 Prepro… ## 5 5 21 accuracy binary 0.630 10 0.0149 Prepro… ## 6 5 21 roc_auc binary 0.636 10 0.0172 Prepro… ## 7 1 13 accuracy binary 0.613 10 0.0124 Prepro… ## 8 1 13 roc_auc binary 0.639 10 0.0164 Prepro… ## 9 5 8 accuracy binary 0.610 10 0.0153 Prepro… ## 10 5 8 roc_auc binary 0.629 10 0.0166 Prepro… ## # … with 10 more rows, and abbreviated variable name ## # ¹.estimator ``` --- ```r rf_results %>% collect_metrics(summarize = FALSE) ## # A tibble: 200 × 7 ## id mtry min_n .metric .estimator .estimate .config ## <chr> <int> <int> <chr> <chr> <dbl> <chr> ## 1 Fold01 4 28 accuracy binary 0.635 Preproc… ## 2 Fold01 4 28 roc_auc binary 0.543 Preproc… ## 3 Fold02 4 28 accuracy binary 0.635 Preproc… ## 4 Fold02 4 28 roc_auc binary 0.704 Preproc… ## 5 Fold03 4 28 accuracy binary 0.635 Preproc… ## 6 Fold03 4 28 roc_auc binary 0.657 Preproc… ## 7 Fold04 4 28 accuracy binary 0.571 Preproc… ## 8 Fold04 4 28 roc_auc binary 0.636 Preproc… ## 9 Fold05 4 28 accuracy binary 0.568 Preproc… ## 10 Fold05 4 28 roc_auc binary 0.655 Preproc… ## # … with 190 more rows ``` --- ## `tune_grid()` A version of `fit_resamples()` that performs a grid search for the best combination of tuned hyper-parameters. .pull-left[ ```r tune_grid( object, resamples, ..., * grid = df, metrics = NULL, control = control_grid() ) ``` ] .pull-right[ A data frame of tuning combinations. 
]

---

## `expand_grid()`

Takes one or more vectors, and returns a data frame holding all combinations of their values.

```r
expand_grid(mtry = c(1, 5), min_n = 1:3)
## # A tibble: 6 × 2
##    mtry min_n
##   <dbl> <int>
## 1     1     1
## 2     1     2
## 3     1     3
## 4     5     1
## 5     5     2
## 6     5     3
```

--

.footnote[tidyr package; see also base `expand.grid()`]

---

## `show_best()`

Shows the .display[n] best combinations of hyper-parameters

```r
rf_results %>% 
  show_best(metric = "roc_auc", n = 5)
```

---

```
## # A tibble: 5 × 8
##    mtry min_n .metric .estimator  mean     n std_err .config
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>
## 1     3    36 roc_auc binary     0.641    10  0.0178 Prepro…
## 2     1    13 roc_auc binary     0.639    10  0.0164 Prepro…
## 3     4    28 roc_auc binary     0.639    10  0.0165 Prepro…
## 4     3    20 roc_auc binary     0.638    10  0.0173 Prepro…
## 5     2    15 roc_auc binary     0.637    10  0.0159 Prepro…
```

---

## `autoplot()`

Quickly visualize tuning results

```r
rf_results %>% autoplot()
```

---

<img src="index_files/figure-html/unnamed-chunk-81-1.png" width="80%" style="display: block; margin: auto;" />

---

## `select_best()`

Shows the .display[top] combination of hyper-parameters.

```r
bechdel_best <- 
  rf_results %>% 
  select_best(metric = "roc_auc")

bechdel_best
```

---

```
## # A tibble: 1 × 3
##    mtry min_n .config
##   <int> <int> <chr>
## 1     3    36 Preprocessor1_Model02
```

---

## `finalize_workflow()`

Replaces `tune()` placeholders in a model/recipe/workflow with a set of hyper-parameter values.

```r
last_rf_workflow <- 
  rf_wf %>%
  finalize_workflow(bechdel_best)
```

---
background-image: url(images/diamonds.jpg)
background-size: contain
background-position: left
class: middle, center
background-color: #f5f5f5

.pull-right[
## We are ready to touch the jewels...

## The .display[testing set]!
] --- ## `last_fit()` ```r last_rf_fit <- last_rf_workflow %>% last_fit(split = bechdel_split) ``` --- ```r last_rf_fit ## # Resampling results ## # Manual resampling ## # A tibble: 1 × 6 ## splits id .metrics .notes .predi…¹ ## <list> <chr> <list> <list> <list> ## 1 <split [1253/141]> train/test … <tibble> <tibble> <tibble> ## # … with 1 more variable: .workflow <list>, and abbreviated ## # variable name ¹.predictions ``` --- class: inverse ## ⏱ Your Turn 5 Use `select_best()`, `finalize_workflow()`, and `last_fit()` to take the best combination of hyper-parameters from `rf_results` and use them to predict the test set. How does our actual test ROC AUC compare to our cross-validated estimate?
--- ```r bechdel_best <- rf_results %>% select_best(metric = "roc_auc") last_rf_workflow <- rf_wf %>% finalize_workflow(bechdel_best) last_rf_fit <- last_rf_workflow %>% last_fit(split = bechdel_split) last_rf_fit %>% collect_metrics() ``` --- ## Final metrics ```r last_rf_fit %>% collect_metrics() ## # A tibble: 2 × 4 ## .metric .estimator .estimate .config ## <chr> <chr> <dbl> <chr> ## 1 accuracy binary 0.681 Preprocessor1_Model1 ## 2 roc_auc binary 0.723 Preprocessor1_Model1 ``` --- ## Final test predictions ```r last_rf_fit %>% collect_predictions() ## # A tibble: 141 × 7 ## id .pred…¹ .pred…² .row .pred…³ test .config ## <chr> <dbl> <dbl> <int> <fct> <fct> <chr> ## 1 train/test s… 0.518 0.482 15 Fail Fail Prepro… ## 2 train/test s… 0.649 0.351 20 Fail Pass Prepro… ## 3 train/test s… 0.644 0.356 34 Fail Pass Prepro… ## 4 train/test s… 0.594 0.406 40 Fail Fail Prepro… ## 5 train/test s… 0.737 0.263 42 Fail Fail Prepro… ## 6 train/test s… 0.560 0.440 65 Fail Fail Prepro… ## 7 train/test s… 0.441 0.559 83 Pass Fail Prepro… ## 8 train/test s… 0.480 0.520 102 Pass Fail Prepro… ## 9 train/test s… 0.476 0.524 125 Pass Pass Prepro… ## 10 train/test s… 0.599 0.401 131 Fail Pass Prepro… ## # … with 131 more rows, and abbreviated variable names ## # ¹.pred_Fail, ².pred_Pass, ³.pred_class ``` --- ```r roc_values <- last_rf_fit %>% collect_predictions() %>% roc_curve(truth = test, estimate = .pred_Fail) autoplot(roc_values) ``` <img src="index_files/figure-html/unnamed-chunk-91-1.png" width="70%" style="display: block; margin: auto;" /> --- class: inverse, middle # The entire process --- ## The set-up ```r set.seed(100) # Important! 
# holdout method bechdel_split <- initial_split(bechdel, strata = test, prop = .9) bechdel_train <- training(bechdel_split) bechdel_test <- testing(bechdel_split) # add cross-validation set.seed(100) bechdel_folds <- vfold_cv(bechdel_train, v = 10, strata = test) ``` --- ## The tune-up ```r # here comes the actual ML bits… # pick model to tune rf_tuner <- rand_forest( engine = "ranger", mtry = tune(), min_n = tune() ) %>% set_mode("classification") rf_wf <- workflow() %>% add_formula(test ~ .) %>% add_model(rf_tuner) rf_results <- rf_wf %>% tune_grid( resamples = bechdel_folds, control = control_grid(save_pred = TRUE) ) ``` --- ## Quick check-in... ```r rf_results %>% collect_predictions() %>% group_by(.config, mtry, min_n) %>% summarize(folds = n_distinct(id)) ## # A tibble: 10 × 4 ## # Groups: .config, mtry [10] ## .config mtry min_n folds ## <chr> <int> <int> <int> ## 1 Preprocessor1_Model01 8 29 10 ## 2 Preprocessor1_Model02 3 7 10 ## 3 Preprocessor1_Model03 6 38 10 ## 4 Preprocessor1_Model04 4 15 10 ## 5 Preprocessor1_Model05 1 6 10 ## 6 Preprocessor1_Model06 2 24 10 ## 7 Preprocessor1_Model07 7 11 10 ## 8 Preprocessor1_Model08 3 28 10 ## 9 Preprocessor1_Model09 5 33 10 ## 10 Preprocessor1_Model10 6 19 10 ``` --- ## The match up! 
.pull-left[ ```r show_best(rf_results, metric = "roc_auc", n = 5) ## # A tibble: 5 × 8 ## mtry min_n .metric .estimator mean n std_err .config ## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 2 24 roc_auc binary 0.640 10 0.0160 Prepro… ## 2 1 6 roc_auc binary 0.637 10 0.0174 Prepro… ## 3 4 15 roc_auc binary 0.636 10 0.0171 Prepro… ## 4 3 28 roc_auc binary 0.636 10 0.0162 Prepro… ## 5 5 33 roc_auc binary 0.635 10 0.0167 Prepro… # pick final model workflow bechdel_best <- rf_results %>% select_best(metric = "roc_auc") bechdel_best ## # A tibble: 1 × 3 ## mtry min_n .config ## <int> <int> <chr> ## 1 2 24 Preprocessor1_Model06 ``` ] .pull-right[ <img src="index_files/figure-html/unnamed-chunk-96-1.png" width="80%" style="display: block; margin: auto;" /> ] --- ## The wrap-up .pull-left[ ```r last_rf_workflow <- rf_wf %>% finalize_workflow(bechdel_best) last_rf_workflow ## ══ Workflow ════════════════════════════════════════════════════════════════════ ## Preprocessor: Formula ## Model: rand_forest() ## ## ── Preprocessor ──────────────────────────────────────────────────────────────── ## test ~ . ## ## ── Model ─────────────────────────────────────────────────────────────────────── ## Random Forest Model Specification (classification) ## ## Main Arguments: ## mtry = 2 ## min_n = 24 ## ## Computational engine: ranger ``` ] -- .pull-right[ ```r # train + test final model last_rf_fit <- last_rf_workflow %>% last_fit(split = bechdel_split) # explore final model last_rf_fit %>% collect_metrics() ## # A tibble: 2 × 4 ## .metric .estimator .estimate .config ## <chr> <chr> <dbl> <chr> ## 1 accuracy binary 0.688 Preprocessor1_Model1 ## 2 roc_auc binary 0.717 Preprocessor1_Model1 last_rf_fit %>% collect_predictions() %>% roc_curve(truth = test, estimate = .pred_Fail) %>% autoplot() ``` ]
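---

## ROC AUC by hand

A minimal base-R sketch of what the area under the ROC curve measures: the share of (Fail, Pass) pairs in which the true Fail gets the higher `.pred_Fail`. The ten values are copied from the rows printed on the "Final test predictions" slide; `fail_scores`, `pass_scores`, and `auc` are hypothetical names. Treating `Fail` as the event mirrors the `.pred_Fail` column used with `roc_curve()` above.

```r
# first ten test-set rows as printed on the "Final test predictions" slide
pred_fail <- c(0.518, 0.649, 0.644, 0.594, 0.737,
               0.560, 0.441, 0.480, 0.476, 0.599)
truth <- factor(
  c("Fail", "Pass", "Pass", "Fail", "Fail",
    "Fail", "Fail", "Fail", "Pass", "Pass"),
  levels = c("Fail", "Pass")
)

fail_scores <- pred_fail[truth == "Fail"]
pass_scores <- pred_fail[truth == "Pass"]

# compare every true Fail with every true Pass; ties count as half a win
wins <- outer(fail_scores, pass_scores, ">")
ties <- outer(fail_scores, pass_scores, "==")
auc  <- mean(wins + 0.5 * ties)

auc
```

Ten rows are far too few to be representative, so this value should not be expected to match the estimate from the full 141-row test set.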