Machine Learning Workflow: Regression Trees

Econ 425T

Author

Dr. Hua Zhou @ UCLA

Published

February 20, 2023

Display system information for reproducibility.

import IPython
print(IPython.sys_info())
{'commit_hash': 'add5877a4',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/IPython',
 'ipython_version': '8.8.0',
 'os_name': 'posix',
 'platform': 'macOS-10.16-x86_64-i386-64bit',
 'sys_executable': '/Library/Frameworks/Python.framework/Versions/3.10/bin/python3',
 'sys_platform': 'darwin',
 'sys_version': '3.10.9 (v3.10.9:1dd9be6584, Dec  6 2022, 14:37:36) [Clang '
                '13.0.0 (clang-1300.0.29.30)]'}

1 Overview

We illustrate the typical machine learning workflow for regression trees using the Hitters data set from the R ISLR2 package.

  1. Initial splitting to test and non-test sets.

  2. Pre-processing of data: not much is needed for regression trees.

  3. Tune the cost complexity pruning hyper-parameter(s) using 6-fold cross-validation (CV) on the non-test data.

  4. Choose the best model by CV and refit it on the whole non-test data.

  5. Final prediction on the test data.

2 Hitters data

Documentation of the Hitters data is here. The goal is to predict log(Salary) (at the opening of the 1987 season) of MLB players from their performance metrics in the 1986 season and over their careers.

# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font sizes in plots
sns.set(font_scale = 1.2)
# Display all columns
pd.set_option('display.max_columns', None)

Hitters = pd.read_csv("../data/Hitters.csv")
Hitters
     AtBat  Hits  HmRun  Runs  RBI  Walks  Years  CAtBat  CHits  CHmRun  \
0      293    66      1    30   29     14      1     293     66       1   
1      315    81      7    24   38     39     14    3449    835      69   
2      479   130     18    66   72     76      3    1624    457      63   
3      496   141     20    65   78     37     11    5628   1575     225   
4      321    87     10    39   42     30      2     396    101      12   
..     ...   ...    ...   ...  ...    ...    ...     ...    ...     ...   
317    497   127      7    65   48     37      5    2703    806      32   
318    492   136      5    76   50     94     12    5511   1511      39   
319    475   126      3    61   43     52      6    1700    433       7   
320    573   144      9    85   60     78      8    3198    857      97   
321    631   170      9    77   44     31     11    4908   1457      30   

     CRuns  CRBI  CWalks League Division  PutOuts  Assists  Errors  Salary  \
0       30    29      14      A        E      446       33      20     NaN   
1      321   414     375      N        W      632       43      10   475.0   
2      224   266     263      A        W      880       82      14   480.0   
3      828   838     354      N        E      200       11       3   500.0   
4       48    46      33      N        E      805       40       4    91.5   
..     ...   ...     ...    ...      ...      ...      ...     ...     ...   
317    379   311     138      N        E      325        9       3   700.0   
318    897   451     875      A        E      313      381      20   875.0   
319    217    93     146      A        W       37      113       7   385.0   
320    470   420     332      A        E     1314      131      12   960.0   
321    775   357     249      A        W      408        4       3  1000.0   

    NewLeague  
0           A  
1           N  
2           A  
3           N  
4           N  
..        ...  
317         N  
318         A  
319         A  
320         A  
321         A  

[322 rows x 20 columns]
# Numerical summaries
Hitters.describe()
            AtBat        Hits       HmRun        Runs         RBI       Walks  \
count  322.000000  322.000000  322.000000  322.000000  322.000000  322.000000   
mean   380.928571  101.024845   10.770186   50.909938   48.027950   38.742236   
std    153.404981   46.454741    8.709037   26.024095   26.166895   21.639327   
min     16.000000    1.000000    0.000000    0.000000    0.000000    0.000000   
25%    255.250000   64.000000    4.000000   30.250000   28.000000   22.000000   
50%    379.500000   96.000000    8.000000   48.000000   44.000000   35.000000   
75%    512.000000  137.000000   16.000000   69.000000   64.750000   53.000000   
max    687.000000  238.000000   40.000000  130.000000  121.000000  105.000000   

            Years       CAtBat        CHits      CHmRun        CRuns  \
count  322.000000    322.00000   322.000000  322.000000   322.000000   
mean     7.444099   2648.68323   717.571429   69.490683   358.795031   
std      4.926087   2324.20587   654.472627   86.266061   334.105886   
min      1.000000     19.00000     4.000000    0.000000     1.000000   
25%      4.000000    816.75000   209.000000   14.000000   100.250000   
50%      6.000000   1928.00000   508.000000   37.500000   247.000000   
75%     11.000000   3924.25000  1059.250000   90.000000   526.250000   
max     24.000000  14053.00000  4256.000000  548.000000  2165.000000   

              CRBI       CWalks      PutOuts     Assists      Errors  \
count   322.000000   322.000000   322.000000  322.000000  322.000000   
mean    330.118012   260.239130   288.937888  106.913043    8.040373   
std     333.219617   267.058085   280.704614  136.854876    6.368359   
min       0.000000     0.000000     0.000000    0.000000    0.000000   
25%      88.750000    67.250000   109.250000    7.000000    3.000000   
50%     220.500000   170.500000   212.000000   39.500000    6.000000   
75%     426.250000   339.250000   325.000000  166.000000   11.000000   
max    1659.000000  1566.000000  1378.000000  492.000000   32.000000   

            Salary  
count   263.000000  
mean    535.925882  
std     451.118681  
min      67.500000  
25%     190.000000  
50%     425.000000  
75%     750.000000  
max    2460.000000  

Graphical summary:

# Graphical summaries
plt.figure()
sns.pairplot(data = Hitters);
plt.show()

There are 59 NAs for Salary. Let’s drop those cases. We are left with 263 data points.
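
As a quick check, the missing salaries can be counted directly:

# Number of missing values in the Salary column (should be 59)
Hitters['Salary'].isna().sum()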

Hitters.dropna(inplace = True)
Hitters.shape
(263, 20)

3 Initial split into test and non-test sets

We randomly split the data in half: one half as the test set and the other half as the non-test set.

from sklearn.model_selection import train_test_split

Hitters_other, Hitters_test = train_test_split(
  Hitters, 
  train_size = 0.5,
  random_state = 425, # seed
  )
Hitters_test.shape
(132, 20)
Hitters_other.shape
(131, 20)

Separate \(X\) and \(y\). We will use 9 of the features.

features = ['Assists', 'AtBat', 'Hits', 'HmRun', 'PutOuts', 'RBI', 'Runs', 'Walks', 'Years']
# Non-test X and y
X_other = Hitters_other[features]
y_other = np.log(Hitters_other.Salary)
# Test X and y
X_test = Hitters_test[features]
y_test = np.log(Hitters_test.Salary)

4 Preprocessing (Python) or recipe (R)

Not much preprocessing is needed here since all predictors are quantitative.
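
Had we kept the categorical predictors (League, Division, NewLeague), a one-hot encoding step could be added in front of the model. A minimal sketch (the preprocessor name col_tf is hypothetical; it is not used in this workflow):

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

# Hypothetical preprocessor: one-hot encode the categorical columns,
# pass the quantitative columns through unchanged
col_tf = ColumnTransformer(
  transformers = [("cat", OneHotEncoder(), ["League", "Division", "NewLeague"])],
  remainder = "passthrough"
  )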

5 Model

from sklearn.tree import DecisionTreeRegressor, plot_tree

regtree_mod = DecisionTreeRegressor(random_state = 425)

6 Pipeline (Python) or workflow (R)

Here we bundle the preprocessing step (Python) or recipe (R) and the model.

from sklearn.pipeline import Pipeline

pipe = Pipeline(steps = [
  ("model", regtree_mod)
  ])
pipe
Pipeline(steps=[('model', DecisionTreeRegressor(random_state=425))])
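
The step name "model" is how the tree’s hyper-parameters are addressed in the tuning grid below (e.g., model__ccp_alpha). The available keys can be listed with:

pipe.get_params().keys()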

7 Tuning grid

ccp_alpha is the Minimal Cost-Complexity Pruning parameter. Greater values of ccp_alpha increase the number of nodes pruned.

# Tune hyper-parameter(s)
ccp_alpha_grid = np.linspace(start = 0.0, stop = 0.1, num = 100)
tuned_parameters = {
  "model__ccp_alpha": ccp_alpha_grid
  }
tuned_parameters  
{'model__ccp_alpha': array([0.        , 0.0010101 , 0.0020202 , 0.0030303 , 0.0040404 ,
       0.00505051, 0.00606061, 0.00707071, 0.00808081, 0.00909091,
       0.01010101, 0.01111111, 0.01212121, 0.01313131, 0.01414141,
       0.01515152, 0.01616162, 0.01717172, 0.01818182, 0.01919192,
       0.02020202, 0.02121212, 0.02222222, 0.02323232, 0.02424242,
       0.02525253, 0.02626263, 0.02727273, 0.02828283, 0.02929293,
       0.03030303, 0.03131313, 0.03232323, 0.03333333, 0.03434343,
       0.03535354, 0.03636364, 0.03737374, 0.03838384, 0.03939394,
       0.04040404, 0.04141414, 0.04242424, 0.04343434, 0.04444444,
       0.04545455, 0.04646465, 0.04747475, 0.04848485, 0.04949495,
       0.05050505, 0.05151515, 0.05252525, 0.05353535, 0.05454545,
       0.05555556, 0.05656566, 0.05757576, 0.05858586, 0.05959596,
       0.06060606, 0.06161616, 0.06262626, 0.06363636, 0.06464646,
       0.06565657, 0.06666667, 0.06767677, 0.06868687, 0.06969697,
       0.07070707, 0.07171717, 0.07272727, 0.07373737, 0.07474747,
       0.07575758, 0.07676768, 0.07777778, 0.07878788, 0.07979798,
       0.08080808, 0.08181818, 0.08282828, 0.08383838, 0.08484848,
       0.08585859, 0.08686869, 0.08787879, 0.08888889, 0.08989899,
       0.09090909, 0.09191919, 0.09292929, 0.09393939, 0.09494949,
       0.0959596 , 0.0969697 , 0.0979798 , 0.0989899 , 0.1       ])}
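
As an aside, instead of a uniform grid, the candidate alphas could be taken from the tree’s own pruning path via scikit-learn’s cost_complexity_pruning_path (DecisionTreeRegressor was imported above). A sketch, not used in the tuning below:

# Effective alphas at which subtrees of the fully grown tree get pruned
path = DecisionTreeRegressor(random_state = 425).cost_complexity_pruning_path(X_other, y_other)
path.ccp_alphas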

8 Cross-validation (CV)

Set up CV partitions and CV criterion.

from sklearn.model_selection import GridSearchCV

# Set up CV
n_folds = 6
search = GridSearchCV(
  pipe,
  tuned_parameters,
  cv = n_folds, 
  scoring = "neg_root_mean_squared_error",
  # Refit the best model on the whole data set
  refit = True
  )

Fit CV. This is typically the most time-consuming step.

# Fit CV
search.fit(X_other, y_other)
GridSearchCV(cv=6,
             estimator=Pipeline(steps=[('model',
                                        DecisionTreeRegressor(random_state=425))]),
             param_grid={'model__ccp_alpha': array([0.        , 0.0010101 , 0.0020202 , 0.0030303 , 0.0040404 ,
       0.00505051, 0.00606061, 0.00707071, 0.00808081, 0.00909091,
       0.01010101, 0.01111111, 0.01212121, 0.01313131, 0.01414141,
       0.01515152, 0.01616162, 0.01717172, 0.01818182, 0.01919192,
       0.020202...
       0.07070707, 0.07171717, 0.07272727, 0.07373737, 0.07474747,
       0.07575758, 0.07676768, 0.07777778, 0.07878788, 0.07979798,
       0.08080808, 0.08181818, 0.08282828, 0.08383838, 0.08484848,
       0.08585859, 0.08686869, 0.08787879, 0.08888889, 0.08989899,
       0.09090909, 0.09191919, 0.09292929, 0.09393939, 0.09494949,
       0.0959596 , 0.0969697 , 0.0979798 , 0.0989899 , 0.1       ])},
             scoring='neg_root_mean_squared_error')
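
If the grid search is slow, GridSearchCV accepts an n_jobs argument that parallelizes the fits over folds and grid points. A sketch (the name search_par is hypothetical; not run here):

# Hypothetical parallel variant of the search above
search_par = GridSearchCV(
  pipe,
  tuned_parameters,
  cv = n_folds,
  scoring = "neg_root_mean_squared_error",
  refit = True,
  n_jobs = -1  # use all available CPU cores
  )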

Visualize CV results.

cv_res = pd.DataFrame({
  "ccp_alpha": np.array(search.cv_results_["param_model__ccp_alpha"]),
  "rmse": -search.cv_results_["mean_test_score"]
  })

plt.figure()
sns.relplot(
  # kind = "line",
  data = cv_res,
  x = "ccp_alpha",
  y = "rmse"
  ).set(
    xlabel = "CCP Alpha",
    ylabel = "CV RMSE"
);
plt.show()

Best CV RMSE (scikit-learn maximizes scores, so it stores the negative RMSE and we flip the sign):

-search.best_score_
0.5217085199223737
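
The tuning parameter value achieving this score is available from best_params_; it agrees with the ccp_alpha of the fitted pipeline shown in the next section:

search.best_params_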

9 Finalize our model

Now we are done tuning. Finally, let’s fit the final model to the whole non-test data and use the test data to estimate the model performance we expect to see with new data.

Since we called GridSearchCV with refit = True, the best model fit on the whole non-test data is readily available.

search.best_estimator_
Pipeline(steps=[('model',
                 DecisionTreeRegressor(ccp_alpha=0.03636363636363636,
                                       random_state=425))])

Visualize the best regression tree.

plt.figure()
plot_tree(
  search.best_estimator_['model'],
  feature_names = features
  );
plt.show()
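
If the plot is hard to read, a plain-text rendering of the same tree is available via export_text. A sketch:

from sklearn.tree import export_text

print(export_text(
  search.best_estimator_['model'],
  feature_names = features
  ))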

Feature importances:

vi_df = pd.DataFrame({
  "feature": features,
  "vi": search.best_estimator_['model'].feature_importances_
  })

plt.figure()
sns.barplot(
  data = vi_df,
  x = "feature",
  y = "vi"
  ).set(
    xlabel = "Feature",
    ylabel = "VI"
);
plt.xticks(rotation = 90);
plt.show()
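
The impurity-based importances reported by scikit-learn are normalized to sum to one. For easier reading, the table can be sorted:

# Features ordered by importance, largest first
vi_df.sort_values("vi", ascending = False)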

The final prediction RMSE on the test set is

from sklearn.metrics import mean_squared_error

mean_squared_error(
  y_test, 
  search.best_estimator_.predict(X_test), 
  squared = False
  )
0.5698068112280658
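
Equivalently, the test RMSE can be computed by hand; the following sketch should reproduce the value above:

# Manual RMSE on the test set
y_pred = search.best_estimator_.predict(X_test)
np.sqrt(np.mean((y_test - y_pred) ** 2))

Note that the test RMSE (0.570) is somewhat larger than the best CV RMSE (0.522) above.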