Machine Learning Workflow: Lasso Regression

Econ 425T

Author

Dr. Hua Zhou @ UCLA

Published

January 25, 2023

Display system information for reproducibility.

import IPython
print(IPython.sys_info())
{'commit_hash': 'add5877a4',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/IPython',
 'ipython_version': '8.8.0',
 'os_name': 'posix',
 'platform': 'macOS-10.16-x86_64-i386-64bit',
 'sys_executable': '/Library/Frameworks/Python.framework/Versions/3.10/bin/python3',
 'sys_platform': 'darwin',
 'sys_version': '3.10.9 (v3.10.9:1dd9be6584, Dec  6 2022, 14:37:36) [Clang '
                '13.0.0 (clang-1300.0.29.30)]'}
sessionInfo()
R version 4.2.2 (2022-10-31)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Big Sur ... 10.16

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.9        here_1.0.1        lattice_0.20-45   png_0.1-8        
 [5] rprojroot_2.0.3   digest_0.6.30     grid_4.2.2        lifecycle_1.0.3  
 [9] jsonlite_1.8.4    magrittr_2.0.3    evaluate_0.18     rlang_1.0.6      
[13] stringi_1.7.8     cli_3.4.1         rstudioapi_0.14   Matrix_1.5-1     
[17] reticulate_1.26   vctrs_0.5.1       rmarkdown_2.18    tools_4.2.2      
[21] stringr_1.5.0     glue_1.6.2        htmlwidgets_1.6.0 xfun_0.35        
[25] yaml_2.3.6        fastmap_1.1.0     compiler_4.2.2    htmltools_0.5.4  
[29] knitr_1.41       

1 Overview

We illustrate the typical machine learning workflow for regression problems using the Hitters data set from the R package ISLR2. The steps are:

  1. Initial splitting into test and non-test sets.

  2. Pre-processing of data (pipeline in Python, recipe in R).

  3. Choose a learner/method. Lasso in this example.

  4. Tune the hyper-parameter(s) (\(\lambda\) in this example) using \(K\)-fold cross-validation (CV) on the non-test data.

  5. Choose the best model by CV and refit it on the whole non-test data.

  6. Final prediction on the test data.

These steps complete the process of training and evaluating one machine learning method (lasso in this case). We repeat the same process for other learners, e.g., random forests or neural networks, using the same test/non-test and CV splits. The final report compares the learners by their CV and test errors.
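
For reference, the lasso estimate solves (in ISL notation; scikit-learn's alpha and glmnet's penalty both play the role of \(\lambda\), with the squared-error loss scaled by \(1/(2n)\) in those implementations)

\[
\min_{\beta_0, \beta} \; \sum_{i=1}^n \Big(y_i - \beta_0 - \sum_{j=1}^p x_{ij} \beta_j\Big)^2 + \lambda \sum_{j=1}^p |\beta_j|,
\]

where \(\lambda \ge 0\) is the tuning parameter chosen to minimize the \(K\)-fold CV RMSE.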

2 Hitters data

Documentation of the Hitters data is here. The goal is to predict the salary (at the opening of the 1987 season) of MLB players from their performance metrics in the 1986-87 season.

# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font sizes in plots
sns.set(font_scale = 2)
# Display all columns
pd.set_option('display.max_columns', None)

Hitters = pd.read_csv("../data/Hitters.csv")
Hitters
     AtBat  Hits  HmRun  Runs  RBI  Walks  Years  CAtBat  CHits  CHmRun  \
0      293    66      1    30   29     14      1     293     66       1   
1      315    81      7    24   38     39     14    3449    835      69   
2      479   130     18    66   72     76      3    1624    457      63   
3      496   141     20    65   78     37     11    5628   1575     225   
4      321    87     10    39   42     30      2     396    101      12   
..     ...   ...    ...   ...  ...    ...    ...     ...    ...     ...   
317    497   127      7    65   48     37      5    2703    806      32   
318    492   136      5    76   50     94     12    5511   1511      39   
319    475   126      3    61   43     52      6    1700    433       7   
320    573   144      9    85   60     78      8    3198    857      97   
321    631   170      9    77   44     31     11    4908   1457      30   

     CRuns  CRBI  CWalks League Division  PutOuts  Assists  Errors  Salary  \
0       30    29      14      A        E      446       33      20     NaN   
1      321   414     375      N        W      632       43      10   475.0   
2      224   266     263      A        W      880       82      14   480.0   
3      828   838     354      N        E      200       11       3   500.0   
4       48    46      33      N        E      805       40       4    91.5   
..     ...   ...     ...    ...      ...      ...      ...     ...     ...   
317    379   311     138      N        E      325        9       3   700.0   
318    897   451     875      A        E      313      381      20   875.0   
319    217    93     146      A        W       37      113       7   385.0   
320    470   420     332      A        E     1314      131      12   960.0   
321    775   357     249      A        W      408        4       3  1000.0   

    NewLeague  
0           A  
1           N  
2           A  
3           N  
4           N  
..        ...  
317         N  
318         A  
319         A  
320         A  
321         A  

[322 rows x 20 columns]
# Numerical summaries
Hitters.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 322 entries, 0 to 321
Data columns (total 20 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   AtBat      322 non-null    int64  
 1   Hits       322 non-null    int64  
 2   HmRun      322 non-null    int64  
 3   Runs       322 non-null    int64  
 4   RBI        322 non-null    int64  
 5   Walks      322 non-null    int64  
 6   Years      322 non-null    int64  
 7   CAtBat     322 non-null    int64  
 8   CHits      322 non-null    int64  
 9   CHmRun     322 non-null    int64  
 10  CRuns      322 non-null    int64  
 11  CRBI       322 non-null    int64  
 12  CWalks     322 non-null    int64  
 13  League     322 non-null    object 
 14  Division   322 non-null    object 
 15  PutOuts    322 non-null    int64  
 16  Assists    322 non-null    int64  
 17  Errors     322 non-null    int64  
 18  Salary     263 non-null    float64
 19  NewLeague  322 non-null    object 
dtypes: float64(1), int64(16), object(3)
memory usage: 50.4+ KB
Hitters.describe()
            AtBat        Hits       HmRun        Runs         RBI       Walks  \
count  322.000000  322.000000  322.000000  322.000000  322.000000  322.000000   
mean   380.928571  101.024845   10.770186   50.909938   48.027950   38.742236   
std    153.404981   46.454741    8.709037   26.024095   26.166895   21.639327   
min     16.000000    1.000000    0.000000    0.000000    0.000000    0.000000   
25%    255.250000   64.000000    4.000000   30.250000   28.000000   22.000000   
50%    379.500000   96.000000    8.000000   48.000000   44.000000   35.000000   
75%    512.000000  137.000000   16.000000   69.000000   64.750000   53.000000   
max    687.000000  238.000000   40.000000  130.000000  121.000000  105.000000   

            Years       CAtBat        CHits      CHmRun        CRuns  \
count  322.000000    322.00000   322.000000  322.000000   322.000000   
mean     7.444099   2648.68323   717.571429   69.490683   358.795031   
std      4.926087   2324.20587   654.472627   86.266061   334.105886   
min      1.000000     19.00000     4.000000    0.000000     1.000000   
25%      4.000000    816.75000   209.000000   14.000000   100.250000   
50%      6.000000   1928.00000   508.000000   37.500000   247.000000   
75%     11.000000   3924.25000  1059.250000   90.000000   526.250000   
max     24.000000  14053.00000  4256.000000  548.000000  2165.000000   

              CRBI       CWalks      PutOuts     Assists      Errors  \
count   322.000000   322.000000   322.000000  322.000000  322.000000   
mean    330.118012   260.239130   288.937888  106.913043    8.040373   
std     333.219617   267.058085   280.704614  136.854876    6.368359   
min       0.000000     0.000000     0.000000    0.000000    0.000000   
25%      88.750000    67.250000   109.250000    7.000000    3.000000   
50%     220.500000   170.500000   212.000000   39.500000    6.000000   
75%     426.250000   339.250000   325.000000  166.000000   11.000000   
max    1659.000000  1566.000000  1378.000000  492.000000   32.000000   

            Salary  
count   263.000000  
mean    535.925882  
std     451.118681  
min      67.500000  
25%     190.000000  
50%     425.000000  
75%     750.000000  
max    2460.000000  

The graphical summary takes longer to run, so it is suppressed here.

# Graphical summaries
sns.pairplot(data = Hitters)

There are 59 NAs for the salary. Let’s drop those cases. We are left with 263 data points.
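
A quick count to confirm (not shown in the original output):

# 322 - 263 = 59 missing salaries
Hitters['Salary'].isna().sum()
59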

Hitters.dropna(inplace = True)
Hitters.shape
(263, 20)
library(GGally)
library(ISLR2)
library(tidymodels)
library(tidyverse)

Hitters <- as_tibble(Hitters) %>% print(width = Inf)
# A tibble: 322 × 20
   AtBat  Hits HmRun  Runs   RBI Walks Years CAtBat CHits CHmRun CRuns  CRBI
   <int> <int> <int> <int> <int> <int> <int>  <int> <int>  <int> <int> <int>
 1   293    66     1    30    29    14     1    293    66      1    30    29
 2   315    81     7    24    38    39    14   3449   835     69   321   414
 3   479   130    18    66    72    76     3   1624   457     63   224   266
 4   496   141    20    65    78    37    11   5628  1575    225   828   838
 5   321    87    10    39    42    30     2    396   101     12    48    46
 6   594   169     4    74    51    35    11   4408  1133     19   501   336
 7   185    37     1    23     8    21     2    214    42      1    30     9
 8   298    73     0    24    24     7     3    509   108      0    41    37
 9   323    81     6    26    32     8     2    341    86      6    32    34
10   401    92    17    49    66    65    13   5206  1332    253   784   890
   CWalks League Division PutOuts Assists Errors Salary NewLeague
    <int> <fct>  <fct>      <int>   <int>  <int>  <dbl> <fct>    
 1     14 A      E            446      33     20   NA   A        
 2    375 N      W            632      43     10  475   N        
 3    263 A      W            880      82     14  480   A        
 4    354 N      E            200      11      3  500   N        
 5     33 N      E            805      40      4   91.5 N        
 6    194 A      W            282     421     25  750   A        
 7     24 N      E             76     127      7   70   A        
 8     12 A      W            121     283      9  100   A        
 9      8 N      W            143     290     19   75   N        
10    866 A      E              0       0      0 1100   A        
# … with 312 more rows
# Numerical summaries
summary(Hitters)
     AtBat            Hits         HmRun            Runs       
 Min.   : 16.0   Min.   :  1   Min.   : 0.00   Min.   :  0.00  
 1st Qu.:255.2   1st Qu.: 64   1st Qu.: 4.00   1st Qu.: 30.25  
 Median :379.5   Median : 96   Median : 8.00   Median : 48.00  
 Mean   :380.9   Mean   :101   Mean   :10.77   Mean   : 50.91  
 3rd Qu.:512.0   3rd Qu.:137   3rd Qu.:16.00   3rd Qu.: 69.00  
 Max.   :687.0   Max.   :238   Max.   :40.00   Max.   :130.00  
                                                               
      RBI             Walks            Years            CAtBat       
 Min.   :  0.00   Min.   :  0.00   Min.   : 1.000   Min.   :   19.0  
 1st Qu.: 28.00   1st Qu.: 22.00   1st Qu.: 4.000   1st Qu.:  816.8  
 Median : 44.00   Median : 35.00   Median : 6.000   Median : 1928.0  
 Mean   : 48.03   Mean   : 38.74   Mean   : 7.444   Mean   : 2648.7  
 3rd Qu.: 64.75   3rd Qu.: 53.00   3rd Qu.:11.000   3rd Qu.: 3924.2  
 Max.   :121.00   Max.   :105.00   Max.   :24.000   Max.   :14053.0  
                                                                     
     CHits            CHmRun           CRuns             CRBI        
 Min.   :   4.0   Min.   :  0.00   Min.   :   1.0   Min.   :   0.00  
 1st Qu.: 209.0   1st Qu.: 14.00   1st Qu.: 100.2   1st Qu.:  88.75  
 Median : 508.0   Median : 37.50   Median : 247.0   Median : 220.50  
 Mean   : 717.6   Mean   : 69.49   Mean   : 358.8   Mean   : 330.12  
 3rd Qu.:1059.2   3rd Qu.: 90.00   3rd Qu.: 526.2   3rd Qu.: 426.25  
 Max.   :4256.0   Max.   :548.00   Max.   :2165.0   Max.   :1659.00  
                                                                     
     CWalks        League  Division    PutOuts          Assists     
 Min.   :   0.00   A:175   E:157    Min.   :   0.0   Min.   :  0.0  
 1st Qu.:  67.25   N:147   W:165    1st Qu.: 109.2   1st Qu.:  7.0  
 Median : 170.50                    Median : 212.0   Median : 39.5  
 Mean   : 260.24                    Mean   : 288.9   Mean   :106.9  
 3rd Qu.: 339.25                    3rd Qu.: 325.0   3rd Qu.:166.0  
 Max.   :1566.00                    Max.   :1378.0   Max.   :492.0  
                                                                    
     Errors          Salary       NewLeague
 Min.   : 0.00   Min.   :  67.5   A:176    
 1st Qu.: 3.00   1st Qu.: 190.0   N:146    
 Median : 6.00   Median : 425.0            
 Mean   : 8.04   Mean   : 535.9            
 3rd Qu.:11.00   3rd Qu.: 750.0            
 Max.   :32.00   Max.   :2460.0            
                 NA's   :59                

The graphical summary takes longer to run, so it is suppressed here.

# Graphical summaries
ggpairs(
  data = Hitters, 
  mapping = aes(alpha = 0.25), 
  lower = list(continuous = "smooth")
  ) + 
  labs(title = "Hitters Data")

There are 59 NAs for the salary. Let’s drop those cases. We are left with 263 data points.

Hitters <- Hitters %>%
  drop_na()
dim(Hitters)
[1] 263  20

3 Initial split into test and non-test sets

from sklearn.model_selection import train_test_split

Hitters_other, Hitters_test = train_test_split(
  Hitters, 
  train_size = 0.75,
  random_state = 425, # seed
  )
Hitters_test.shape
(66, 20)
Hitters_other.shape
(197, 20)
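
A quick check (not in the original output) that the split partitions the 263 complete cases:

# 197 non-test + 66 test = 263 rows in total
Hitters_other.shape[0] + Hitters_test.shape[0] == Hitters.shape[0]
True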

Separate \(X\) and \(y\).

# Non-test X and y
X_other = Hitters_other.drop('Salary', axis = 1)
y_other = Hitters_other.Salary
# Test X and y
X_test = Hitters_test.drop('Salary', axis = 1)
y_test = Hitters_test.Salary
# For reproducibility
set.seed(425)
data_split <- initial_split(
  Hitters, 
  # # stratify by percentiles
  # strata = "Salary", 
  prop = 0.75
  )

Hitters_other <- training(data_split)
dim(Hitters_other)
[1] 197  20
Hitters_test <- testing(data_split)
dim(Hitters_test)
[1] 66 20

4 Preprocessing (Python) or recipe (R)

For regularization methods such as ridge and lasso, it is essential to center and scale predictors.
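
In a nutshell, the \(\ell_1\) penalty acts on the raw coefficients, so a change in a predictor's units changes how hard that coefficient is penalized. A minimal sketch with simulated data (hypothetical, not the Hitters data):

# Same signal in different units is penalized very differently by the lasso
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.default_rng(0)
x = rng.normal(size = (100, 1))             # predictor on unit scale
y = 2.0 * x[:, 0] + rng.normal(size = 100)  # true slope 2

print(Lasso(alpha = 0.5).fit(x, y).coef_)        # strongly shrunk below 2
print(Lasso(alpha = 0.5).fit(100 * x, y).coef_)  # ~0.02, i.e. ~2 per original unit: almost no shrinkage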

Set up a pre-processor that one-hot encodes the categorical variables and then standardizes all numeric predictors.

from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import make_column_transformer
from sklearn.pipeline import Pipeline

# OHE transformer for categorical variables
cattf = make_column_transformer(
  (OneHotEncoder(drop = 'first'), ['League', 'Division', 'NewLeague']),
  remainder = 'passthrough'
)
# Standardization transformer
scaler = StandardScaler()
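
As a sanity check (not part of the original output), we can fit the column transformer to the non-test predictors. With drop = 'first', each of the three two-level factors becomes a single dummy column, so the 19 input columns map to 3 + 16 = 19 output columns:

# 3 dummy columns + 16 passthrough numeric columns
cattf.fit_transform(X_other).shape
(197, 19)
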
norm_recipe <- 
  recipe(
    Salary ~ ., 
    data = Hitters_other
  ) %>%
  # create traditional dummy variables
  step_dummy(all_nominal()) %>%
  # zero-variance filter
  step_zv(all_predictors()) %>% 
  # center and scale numeric data
  step_normalize(all_predictors()) %>%
  # step_log(Salary, base = 10) %>%
  # estimate the means and standard deviations
  prep(training = Hitters_other, retain = TRUE)
norm_recipe
Recipe

Inputs:

      role #variables
   outcome          1
 predictor         19

Training data contained 197 data points and no missing data.

Operations:

Dummy variables from League, Division, NewLeague [trained]
Zero variance filter removed <none> [trained]
Centering and scaling for AtBat, Hits, HmRun, Runs, RBI, Walks, Years, CA... [trained]

5 Model

from sklearn.linear_model import Lasso

lasso = Lasso(max_iter = 10000)
lasso
Lasso(max_iter=10000)
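
The penalty is left at its default here; it is tuned below through the pipeline parameter model__alpha. A quick look at the default (not shown in the original):

lasso.alpha
1.0
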
lasso_mod <- 
  # mixture = 0 (ridge), mixture = 1 (lasso)
  linear_reg(penalty = tune(), mixture = 1.0) %>% 
  set_engine("glmnet")
lasso_mod
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = tune()
  mixture = 1

Computational engine: glmnet 

6 Pipeline (Python) or workflow (R)

Here we bundle the pre-processing steps (Python) or recipe (R) with the model.

pipe = Pipeline(steps = [
  ("cat_tf", cattf),
  ("std_tf", scalar), 
  ("model", lasso)
  ])
pipe
Pipeline(steps=[('cat_tf',
                 ColumnTransformer(remainder='passthrough',
                                   transformers=[('onehotencoder',
                                                  OneHotEncoder(drop='first'),
                                                  ['League', 'Division',
                                                   'NewLeague'])])),
                ('std_tf', StandardScaler()),
                ('model', Lasso(max_iter=10000))])
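
Once bundled, the pipeline behaves like a single estimator: fitting applies one-hot encoding, then scaling, then the lasso, and predicting re-applies the fitted transformations to new data. A quick sketch (not run in the original notes; it uses the default alpha = 1.0 rather than the tuned penalty):

# Fit the whole pipeline on the non-test data, then predict a few test cases
pipe.fit(X_other, y_other).predict(X_test[:5])
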
lr_wf <- 
  workflow() %>%
  add_model(lasso_mod) %>%
  add_recipe(norm_recipe)
lr_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = tune()
  mixture = 1

Computational engine: glmnet 

7 Tuning grid

Set up a tuning grid of 100 values of \(\lambda\), equally spaced on the \(\log_{10}\) scale: \(10^{-3}\) to \(10^{2}\) for scikit-learn's alpha and \(10^{-2}\) to \(10^{3}\) for glmnet's penalty.

# Tune hyper-parameter(s)
alphas = np.logspace(start = -3, stop = 2, base = 10, num = 100)
tuned_parameters = {"model__alpha": alphas}
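
A quick check of the grid endpoints (not in the original output):

# Smallest and largest candidate penalties: 10^-3 and 10^2
alphas[0], alphas[-1]
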
lambda_grid <-
  grid_regular(penalty(range = c(-2, 3), trans = log10_trans()), levels = 100)
lambda_grid
# A tibble: 100 × 1
   penalty
     <dbl>
 1  0.01  
 2  0.0112
 3  0.0126
 4  0.0142
 5  0.0159
 6  0.0179
 7  0.0201
 8  0.0226
 9  0.0254
10  0.0285
# … with 90 more rows

8 Cross-validation (CV)

Set up the CV partitions and the CV criterion. Note that scikit-learn maximizes the scoring function, so RMSE enters as neg_root_mean_squared_error; we flip the sign back when reporting.

from sklearn.model_selection import GridSearchCV

# Set up CV
n_folds = 10
search = GridSearchCV(
  pipe, 
  tuned_parameters, 
  cv = n_folds, 
  scoring = "neg_root_mean_squared_error",
  # Refit the best model on the whole data set
  refit = True 
  )

Fit CV. This is typically the most time-consuming step: with 100 candidate \(\lambda\) values and 10 folds, 1,000 lasso fits are performed, plus one final refit on the whole non-test data.

# Fit CV
search.fit(X_other, y_other)
GridSearchCV(cv=10,
             estimator=Pipeline(steps=[('cat_tf',
                                        ColumnTransformer(remainder='passthrough',
                                                          transformers=[('onehotencoder',
                                                                         OneHotEncoder(drop='first'),
                                                                         ['League',
                                                                          'Division',
                                                                          'NewLeague'])])),
                                       ('std_tf', StandardScaler()),
                                       ('model', Lasso(max_iter=10000))]),
             param_grid={'model__alpha': array([1.00000000e-03, 1.12332403e-03, 1.26185688e-03, 1.41747416e-03,...
       6.89261210e+00, 7.74263683e+00, 8.69749003e+00, 9.77009957e+00,
       1.09749877e+01, 1.23284674e+01, 1.38488637e+01, 1.55567614e+01,
       1.74752840e+01, 1.96304065e+01, 2.20513074e+01, 2.47707636e+01,
       2.78255940e+01, 3.12571585e+01, 3.51119173e+01, 3.94420606e+01,
       4.43062146e+01, 4.97702356e+01, 5.59081018e+01, 6.28029144e+01,
       7.05480231e+01, 7.92482898e+01, 8.90215085e+01, 1.00000000e+02])},
             scoring='neg_root_mean_squared_error')

CV results.

cv_res = pd.DataFrame({
  "alpha": alphas,
  "rmse": -search.cv_results_["mean_test_score"]
  })

plt.figure()
sns.relplot(
  data = cv_res,
  x = "alpha",
  y = "rmse"
  ).set(
    xlabel = "alpha",
    ylabel = "CV RMSE",
    xscale = "log"
);
plt.show()

Best CV RMSE:

-search.best_score_
327.5225980405363
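
The corresponding tuned penalty is also available (not shown in the original; it matches the alpha in the refit pipeline of Section 9):

search.best_params_
{'model__alpha': 0.4750810162102798}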

Set cross-validation partitions.

set.seed(250)
folds <- vfold_cv(Hitters_other, v = 10)
folds
#  10-fold cross-validation 
# A tibble: 10 × 2
   splits           id    
   <list>           <chr> 
 1 <split [177/20]> Fold01
 2 <split [177/20]> Fold02
 3 <split [177/20]> Fold03
 4 <split [177/20]> Fold04
 5 <split [177/20]> Fold05
 6 <split [177/20]> Fold06
 7 <split [177/20]> Fold07
 8 <split [178/19]> Fold08
 9 <split [178/19]> Fold09
10 <split [178/19]> Fold10

Fit cross-validation.

lasso_fit <- 
  lr_wf %>%
  tune_grid(
    resamples = folds,
    grid = lambda_grid
    )
lasso_fit
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 × 4
   splits           id     .metrics           .notes          
   <list>           <chr>  <list>             <list>          
 1 <split [177/20]> Fold01 <tibble [200 × 5]> <tibble [1 × 3]>
 2 <split [177/20]> Fold02 <tibble [200 × 5]> <tibble [1 × 3]>
 3 <split [177/20]> Fold03 <tibble [200 × 5]> <tibble [1 × 3]>
 4 <split [177/20]> Fold04 <tibble [200 × 5]> <tibble [1 × 3]>
 5 <split [177/20]> Fold05 <tibble [200 × 5]> <tibble [1 × 3]>
 6 <split [177/20]> Fold06 <tibble [200 × 5]> <tibble [1 × 3]>
 7 <split [177/20]> Fold07 <tibble [200 × 5]> <tibble [1 × 3]>
 8 <split [178/19]> Fold08 <tibble [200 × 5]> <tibble [1 × 3]>
 9 <split [178/19]> Fold09 <tibble [200 × 5]> <tibble [1 × 3]>
10 <split [178/19]> Fold10 <tibble [200 × 5]> <tibble [1 × 3]>

There were issues with some computations:

  - Warning(s) x10: A correlation computation is required, but `estimate` is constant...

Run `show_notes(.Last.tune.result)` for more information.

Visualize the CV criterion.

lasso_fit %>%
  collect_metrics() %>%
  print(width = Inf) %>%
  filter(.metric == "rmse") %>%
  ggplot(mapping = aes(x = penalty, y = mean)) + 
  geom_point() + 
  geom_line() + 
  labs(x = "Penalty", y = "CV RMSE") + 
  scale_x_log10(labels = scales::label_number())
# A tibble: 200 × 7
   penalty .metric .estimator    mean     n std_err .config               
     <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>                 
 1  0.01   rmse    standard   341.       10 33.2    Preprocessor1_Model001
 2  0.01   rsq     standard     0.472    10  0.0845 Preprocessor1_Model001
 3  0.0112 rmse    standard   341.       10 33.2    Preprocessor1_Model002
 4  0.0112 rsq     standard     0.472    10  0.0845 Preprocessor1_Model002
 5  0.0126 rmse    standard   341.       10 33.2    Preprocessor1_Model003
 6  0.0126 rsq     standard     0.472    10  0.0845 Preprocessor1_Model003
 7  0.0142 rmse    standard   341.       10 33.2    Preprocessor1_Model004
 8  0.0142 rsq     standard     0.472    10  0.0845 Preprocessor1_Model004
 9  0.0159 rmse    standard   341.       10 33.2    Preprocessor1_Model005
10  0.0159 rsq     standard     0.472    10  0.0845 Preprocessor1_Model005
# … with 190 more rows

Show the top 5 models (\(\lambda\) values):

lasso_fit %>%
  show_best("rmse")
# A tibble: 5 × 7
  penalty .metric .estimator  mean     n std_err .config               
    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
1    4.23 rmse    standard    338.    10    38.1 Preprocessor1_Model053
2    4.75 rmse    standard    338.    10    38.3 Preprocessor1_Model054
3    3.76 rmse    standard    338.    10    37.9 Preprocessor1_Model052
4    3.35 rmse    standard    339.    10    37.6 Preprocessor1_Model051
5    5.34 rmse    standard    339.    10    38.6 Preprocessor1_Model055

Let’s select the best model:

best_lasso <- lasso_fit %>%
  select_best("rmse")
best_lasso
# A tibble: 1 × 2
  penalty .config               
    <dbl> <chr>                 
1    4.23 Preprocessor1_Model053

9 Finalize our model

Now we are done tuning. Finally, let’s fit the selected model to the whole non-test data and use the test data to estimate the performance we expect to see on new data.

Since we called GridSearchCV with refit = True, the best model fit on the whole non-test data is readily available.

search.best_estimator_
Pipeline(steps=[('cat_tf',
                 ColumnTransformer(remainder='passthrough',
                                   transformers=[('onehotencoder',
                                                  OneHotEncoder(drop='first'),
                                                  ['League', 'Division',
                                                   'NewLeague'])])),
                ('std_tf', StandardScaler()),
                ('model', Lasso(alpha=0.4750810162102798, max_iter=10000))])

The final prediction RMSE on the test set is

from sklearn.metrics import mean_squared_error

mean_squared_error(y_test, search.best_estimator_.predict(X_test), squared = False)
359.5173598142268
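
For reference, the reported quantity is the root mean squared error over the \(n = 66\) test cases:

\[
\operatorname{RMSE} = \sqrt{\frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2}.
\]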

The test RMSE (about 360) is somewhat higher than the best CV RMSE (about 328). With only 66 test observations this gap is not alarming, but it is a reminder that CV error, computed on the same data used for tuning, tends to be slightly optimistic.

# Final workflow
final_wf <- lr_wf %>%
  finalize_workflow(best_lasso)
final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = 4.2292428743895
  mixture = 1

Computational engine: glmnet 
# Fit the whole training set, then predict the test cases
final_fit <- 
  final_wf %>%
  last_fit(data_split)
final_fit
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits           id               .metrics .notes   .predictions .workflow 
  <list>           <chr>            <list>   <list>   <list>       <list>    
1 <split [197/66]> train/test split <tibble> <tibble> <tibble>     <workflow>
# Test metrics
final_fit %>% collect_metrics()
# A tibble: 2 × 4
  .metric .estimator .estimate .config             
  <chr>   <chr>          <dbl> <chr>               
1 rmse    standard     319.    Preprocessor1_Model1
2 rsq     standard       0.411 Preprocessor1_Model1