This page was generated from docs/tutorials/basics.ipynb.

Basics#

[ ]:
import sys
sys.path.append('..')

import os
n_cores = int(8)
os.environ["OMP_NUM_THREADS"] = f"{n_cores}"
os.environ["OPENBLAS_NUM_THREADS"] = f"{n_cores}"
os.environ["MKL_NUM_THREADS"] = f"{n_cores}"
os.environ["VECLIB_MAXIMUM_THREADS"] = f"{n_cores}"
os.environ["NUMEXPR_NUM_THREADS"] = f"{n_cores}"
os.environ["NUMBA_CACHE_DIR"]='/tmp/numba_cache'

import numpy as np
from sklearn_ensemble_cv import reset_random_seeds
import matplotlib.pyplot as plt

reset_random_seeds(0)

We make up some fake data for illustration.

[2]:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=300, n_features=200,
                       n_informative=5, n_targets=1,
                       random_state=0, shuffle=False)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

The Ensemble class#

For users who want more control over the ensemble predictor, this section introduces the lower-level classes and objects that we use to implement the cross-validation methods. Users who only need the high-level interface functions can safely skip this section.

We provide Ensemble, a class for ensemble predictors whose base class is sklearn.ensemble.BaggingRegressor. Its usage is therefore essentially the same as that of BaggingRegressor, except that the new class adds several member functions, which we illustrate below.

Initialize an object#

The Ensemble class is initialized in the same way as sklearn.ensemble.BaggingRegressor, by providing:

  1. The base estimator object, whose hyperparameters kwargs_regr are specified when it is initialized. In the following example, we use a decision tree as the base estimator.

  2. The hyperparameters for building the ensemble, such as n_estimators, max_samples, and max_features.

[3]:
from sklearn.tree import DecisionTreeRegressor
from sklearn_ensemble_cv import Ensemble
kwargs_regr = {'max_depth': 7}
kwargs_ensemble = {'max_samples': 0.8}
regr = Ensemble(estimator=DecisionTreeRegressor(**kwargs_regr), n_estimators=100, **kwargs_ensemble)

After the ensemble object regr is initialized, we can fit the data and get the prediction:

[4]:
regr.fit(X_train, y_train)
regr.predict(X_test)
[4]:
array([ 1.55258626e+02, -3.05994688e+01, -3.08717707e+01,  4.59185397e+01,
       -1.31673233e+02, -1.53618401e+01, -7.58969354e+00,  5.55612647e+01,
       -2.49650470e+01, -8.64505794e+01,  5.30470606e+00, -1.65026872e+00,
       -1.68937357e+02, -3.64697321e+01,  1.09240808e+02, -3.61555106e+01,
        1.70139644e+02,  5.07345762e+01, -4.57768586e+01, -6.36637535e+01,
        1.98450292e+01, -5.88202620e+01,  1.32874009e+02, -1.63522372e+02,
       -1.28692376e+02, -5.88109839e+01,  5.19530446e+01,  1.42758358e+01,
        3.30152254e+01, -4.05912157e+01, -5.82506190e+01, -1.24873797e+01,
       -1.75895663e+02,  9.75114150e+01,  8.74697325e+01,  5.13658483e+01,
       -1.27636452e+02, -3.21727865e+01,  1.46534143e+02,  1.51187165e+02,
       -1.23416630e+02, -1.35952080e+01, -4.85622867e+01, -4.10217919e+01,
       -1.45807260e+02,  1.25982796e+02, -6.55997321e+00,  6.87142111e+01,
        6.01303870e+01,  1.59088989e+02, -6.05490517e+01, -1.89612897e+01,
       -7.83841416e+01, -1.98655475e+01, -2.07795599e+00,  9.06000291e+01,
        1.56802764e+02, -4.79037791e+00,  1.63049008e+02, -1.95082888e+02,
        1.98327834e+02,  3.22210898e+01,  1.13020203e+01, -1.23629912e-01,
        8.36659821e+01,  9.17160516e+01, -3.25095987e+01, -1.04825033e+01,
       -3.54178679e+01,  6.38551655e+01,  5.40532228e+01,  2.02374291e+01,
       -7.40274616e+01,  1.32490998e+01, -4.09384762e+01,  1.55996605e+02,
       -3.97031336e+01, -8.84950965e+01, -9.79128997e+01, -1.20241144e+02,
        1.18386965e+02,  1.50236875e+02, -2.09861849e+02, -5.49399208e+01,
       -7.82686509e+01, -1.02451741e+02, -4.25753316e+01, -1.32619362e+02,
       -4.33093105e+01, -4.66723558e+01,  1.65050613e+02, -8.70596762e+01,
        2.30991348e+00,  4.20654775e+01, -1.53247921e+02, -1.57785888e+02,
        1.77328813e+02,  1.17652845e+01,  1.97861320e+01,  7.12165424e+01,
        1.01746011e+02, -2.54085468e+01,  7.05913116e+01,  1.14362733e+01,
       -1.95851324e+02,  2.76538247e+01, -7.24716237e+01, -7.70112213e+01,
       -1.34312439e+02, -1.39850788e+01, -9.38177348e+01, -1.99134689e+01,
       -2.92722574e+01, -1.35896063e+02,  7.67209370e+01,  2.95641797e+01,
       -5.21496990e-01, -1.77364655e+02, -1.27670449e+02, -1.36452786e+02,
        1.19583568e+01, -1.36222497e+02, -8.87966870e+01,  5.22172522e+01,
       -1.30654048e+02,  3.00035105e+01, -2.06507219e+01,  4.73462564e+01,
       -1.50207963e+01,  1.87159008e+01,  4.70190920e+01, -1.53884335e+02,
        3.26657747e+01,  1.06385038e+01,  1.39514153e+02, -1.14965025e+02,
        1.66933169e+02,  1.16198895e+01, -3.71584600e+01, -1.38855162e+01,
       -1.05179205e+02,  3.49003320e+01, -1.11235050e+02, -2.98637151e+01,
        1.74088266e+02,  1.94641882e+02, -6.04350092e+01, -4.83500392e+00,
       -1.43971902e+02,  1.59488916e+02])

Prediction of individual estimators#

We provide a new function, predict_individual, that returns the predictions of every estimator in the ensemble. Since we have \(n=150\) training observations and \(M=100\) estimators, the result has shape \((n,M)=(150,100)\).

[5]:
Y_train_hat = regr.predict_individual(X_train)
Y_train_hat.shape
[5]:
(150, 100)
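
To make the \((n,M)\) layout concrete, here is a minimal sketch that builds the same kind of array with a plain sklearn.ensemble.BaggingRegressor. It only relies on the fitted attributes estimators_ and estimators_features_ of scikit-learn; it is not the implementation of predict_individual, and the toy data and the names X_demo, y_demo, bag, and Y_demo_hat are made up for illustration.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

# Small synthetic problem, unrelated to the data used above.
X_demo, y_demo = make_regression(n_samples=60, n_features=10, random_state=0)

# The base estimator is passed positionally so the call works whether the
# keyword is named `estimator` or `base_estimator` in the installed sklearn.
bag = BaggingRegressor(DecisionTreeRegressor(max_depth=3),
                       n_estimators=5, max_samples=0.8, random_state=0)
bag.fit(X_demo, y_demo)

# One column per fitted base estimator; each estimator is evaluated on the
# feature subset recorded for it in `estimators_features_`.
Y_demo_hat = np.column_stack([
    est.predict(X_demo[:, feats])
    for est, feats in zip(bag.estimators_, bag.estimators_features_)
])

print(Y_demo_hat.shape)  # (n, M) = (60, 5)
```

Averaging the columns of such an array recovers the usual bagged prediction, which is why the per-estimator view is a convenient starting point for risk estimates at different ensemble sizes.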

Compute ECV estimate of the prediction risk#

Below, we use the function compute_ecv_estimate to estimate the prediction risk for ensemble sizes \(M=1,\ldots,100\), using only the first \(M_0=30\) trees.

[6]:
df_est = regr.compute_ecv_estimate(X_train, y_train, M0=30, return_df=True)
df_est
[6]:
M estimate
0 1 16545.259516
1 2 11039.442773
2 3 9204.170525
3 4 8286.534401
4 5 7735.952727
... ... ...
95 96 5648.330545
96 97 5647.148024
97 98 5645.989637
98 99 5644.854651
99 100 5643.742364

100 rows × 2 columns

We can also compute the error on the test set using the function compute_risk.

[7]:
df_risk = regr.compute_risk(X_test, y_test, return_df=True)
df_risk
[7]:
M risk
0 1 15740.161966
1 2 10714.265053
2 3 9699.114654
3 4 8950.480255
4 5 8268.258825
... ... ...
95 96 5875.571676
96 97 5881.180617
97 98 5881.094324
98 99 5882.142533
99 100 5880.883707

100 rows × 2 columns
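
One way to read the M column is that row \(M\) reports the squared-error risk of an ensemble that averages \(M\) estimators. The sketch below computes such a curve from an \((n,M)\) array of individual predictions. It assumes that the size-\(M\) ensemble averages the first \(M\) columns, which is an assumption about ordering rather than something stated here, and risk_by_ensemble_size is a hypothetical helper, not part of sklearn_ensemble_cv.

```python
import numpy as np

def risk_by_ensemble_size(Y_hat, y):
    """Squared-error risk of the average of the first M columns, for M = 1..n_estimators."""
    n_estimators = Y_hat.shape[1]
    # Running average over estimators: column M-1 holds the size-M ensemble prediction.
    running_mean = np.cumsum(Y_hat, axis=1) / np.arange(1, n_estimators + 1)
    return np.mean((y[:, None] - running_mean) ** 2, axis=0)

rng = np.random.default_rng(0)
y_toy = rng.normal(size=50)
# Four noisy "estimators" of y_toy, stored column-wise (toy data only).
Y_toy_hat = y_toy[:, None] + rng.normal(size=(50, 4))

risk_toy = risk_by_ensemble_size(Y_toy_hat, y_toy)
print(risk_toy.shape)  # one risk value per ensemble size: (4,)
```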

If we plot both the risk estimate and the actual test error, we see a close match.

[8]:
plt.plot(df_est['M'], df_est['estimate'], label='estimate')
plt.plot(df_risk['M'], df_risk['risk'], label='risk')
plt.legend()
plt.show()
../_images/tutorials_basics_17_0.png

The two functions above cover the basic usage. One benefit of the ECV method is that it also provides risk estimates beyond \(M=100\), which gives a sense of how much improvement to expect from further increasing the ensemble size.

[9]:
df_est = regr.compute_ecv_estimate(X_train, y_train, M_test=1000, M0=30, return_df=True)
df_est
[9]:
M estimate
0 1 16545.259516
1 2 11039.442773
2 3 9204.170525
3 4 8286.534401
4 5 7735.952727
... ... ...
995 996 5544.681887
996 997 5544.670797
997 998 5544.659731
998 999 5544.648686
999 1000 5544.637663

1000 rows × 2 columns
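
The printed estimates lie on a curve of the form \(R_\infty + 2(R_1 - R_2)/M\) with \(R_\infty = 2R_2 - R_1\), where \(R_1\) and \(R_2\) are the estimates for \(M=1\) and \(M=2\). The check below uses only the numbers shown in the table. It is an observation about these values, not a description of how the library estimates \(R_1\) and \(R_2\) from the first \(M_0\) trees, and extrapolated_risk is a hypothetical helper name.

```python
# Values copied from the first two rows of the table above.
R1, R2 = 16545.259516, 11039.442773

def extrapolated_risk(M, R1=R1, R2=R2):
    """Risk curve of the form R_inf + 2 * (R1 - R2) / M, with R_inf = 2 * R2 - R1."""
    R_inf = 2.0 * R2 - R1
    return R_inf + 2.0 * (R1 - R2) / M

# Compare against the tabulated rows for M = 3, 100, and 1000.
for M in (3, 100, 1000):
    print(M, round(extrapolated_risk(M), 3))
```

The values for \(M=3\), \(M=100\), and \(M=1000\) agree with the tabulated rows to the displayed precision, and the limit \(R_\infty \approx 5533.6\) indicates how little is left to gain beyond a few hundred trees.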

ECV estimate for one configuration of ensemble predictors#

The function comp_empirical_ecv provides an easy way to fit an ensemble and obtain its ECV risk estimate. As in the previous section, one needs to provide

  1. Data: X_train, y_train.

  2. A regressor class and the parameters to initialize it: DecisionTreeRegressor, kwargs_regr=kwargs_regr.

  3. The parameters for building the ensemble (with M denoting n_estimators): kwargs_ensemble=kwargs_ensemble, M=50.

  4. Extra optional parameters for ECV.

The function returns two objects:

  1. An ensemble predictor (an object of Ensemble)

  2. An np.array or pd.DataFrame object containing the risk estimates given by ECV.

[10]:
from sklearn_ensemble_cv import comp_empirical_ecv

kwargs_regr = {'max_depth': 7}
kwargs_ensemble = {'max_samples': 0.8}
regr, risk_ecv = comp_empirical_ecv(X_train, y_train, DecisionTreeRegressor, kwargs_regr=kwargs_regr,
                   kwargs_ensemble=kwargs_ensemble, M=50)

One can also pass kwargs X_val=X_test, Y_val=y_test to get the actual test errors.

ECV for tuning hyperparameters#

To tune hyperparameters, we build two grids: one for the ensemble hyperparameters, such as max_samples and max_features, and one for the base-predictor hyperparameters, such as max_depth and min_samples_leaf.

We recommend using an np.array for each parameter and making sure its dtype is one that sklearn accepts. To fix a hyperparameter at a single value, provide it in the grid as either a scalar or a list/array of length one.
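
As an illustration of why scalars and length-one arrays are interchangeable in a grid, the sketch below expands a dictionary into the list of all configurations. The helper expand_grid and the dictionary grid_demo are hypothetical and are not how sklearn_ensemble_cv handles grids internally; they only show the Cartesian-product idea.

```python
import itertools
import numpy as np

def expand_grid(grid):
    """Hypothetical helper: list every combination in a dict of scalars or arrays."""
    names = list(grid)
    # np.atleast_1d turns a scalar into a length-one array, so fixed values
    # and proper grids can be treated uniformly.
    values = [np.atleast_1d(grid[name]) for name in names]
    return [dict(zip(names, combo)) for combo in itertools.product(*values)]

grid_demo = {
    'max_depth': 7,                        # fixed value given as a scalar
    'max_features': np.array([0.9, 1.0]),  # two candidate values
    'max_samples': [0.6, 0.7],             # a plain list also works here
}
configs = expand_grid(grid_demo)
print(len(configs))  # 1 * 2 * 2 = 4 configurations
```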

[11]:
from sklearn_ensemble_cv import ECV

# Hyperparameters for the base regressor
grid_regr = {
    'max_depth':np.array([6,7], dtype=int),
    }
# Hyperparameters for the ensemble
grid_ensemble = {
    'max_features':np.array([0.9,1.]),
    'max_samples':np.array([0.6,0.7]),
}

res_ecv, info_ecv = ECV(
    X_train, y_train, DecisionTreeRegressor, grid_regr, grid_ensemble,
    X_test=X_test, Y_test=y_test,
    M=50, M0=25, return_df=True
)
[12]:
res_ecv
[12]:
max_depth max_features max_samples risk_val-1 risk_val-2 risk_val-3 risk_val-4 risk_val-5 risk_val-6 risk_val-7 ... risk_test-41 risk_test-42 risk_test-43 risk_test-44 risk_test-45 risk_test-46 risk_test-47 risk_test-48 risk_test-49 risk_test-50
0 6 0.9 0.6 38766.092506 27130.954106 23252.574639 21313.384906 20149.871066 19374.195173 18820.140963 ... 20400.955895 20473.966021 20516.852272 20564.383652 20568.895484 20548.386056 20461.303545 20416.347005 20338.212307 20330.910766
1 6 0.9 0.7 38545.982251 27685.957696 24065.949511 22255.945418 21169.942963 20445.941326 19928.797299 ... 20318.713727 20440.861775 20411.974909 20442.347693 20448.785333 20427.551253 20274.764944 20176.546011 20022.390675 19990.362412
2 6 1.0 0.6 17906.660078 12049.342285 10096.903020 9120.683388 8534.951609 8144.463756 7865.543861 ... 6654.190713 6625.412906 6633.298597 6636.450315 6624.454695 6607.368851 6589.791652 6561.776567 6521.960893 6508.579021
3 6 1.0 0.7 17768.782775 11762.499915 9760.405628 8759.358485 8158.730199 7758.311342 7472.297872 ... 6746.849503 6751.183654 6724.878113 6697.482424 6684.798082 6688.219217 6711.816709 6738.490807 6784.242634 6787.740169
4 7 0.9 0.6 39335.856987 27626.621767 23723.543361 21772.004157 20601.080636 19820.464954 19262.882325 ... 20528.039838 20568.708815 20605.307823 20632.828610 20630.262989 20585.590830 20492.588286 20430.862922 20334.595979 20313.901931
5 7 0.9 0.7 38394.185018 27696.607514 24130.748346 22347.818762 21278.061012 20564.889178 20055.480726 ... 20139.080395 20199.000964 20148.318891 20165.814865 20171.442945 20141.663703 20001.089341 19908.831302 19773.342778 19734.671204
6 7 1.0 0.6 17900.269535 12129.858040 10206.387541 9244.652292 8667.611142 8282.917042 8008.135543 ... 6612.991848 6557.674233 6552.774417 6545.191061 6527.502604 6520.365373 6502.011696 6480.311467 6454.951325 6438.797136
7 7 1.0 0.7 17691.335912 11718.687022 9727.804059 8732.362577 8135.097688 7736.921095 7452.509244 ... 6403.204525 6420.395641 6404.231034 6377.207600 6366.631308 6361.084780 6399.403605 6430.945089 6481.865780 6489.090550

8 rows × 104 columns

[13]:
info_ecv
[13]:
{'best_params_regr': {'max_depth': 7},
 'best_params_ensemble': {'random_state': 0,
  'n_estimators': 50,
  'max_features': 1.0,
  'max_samples': 0.7},
 'best_n_estimators': 50,
 'best_params_index': 7,
 'best_score': 5984.944087677355,
 'delta': 0.0,
 'M_max': inf,
 'best_n_estimators_extrapolate': inf,
 'best_score_extrapolate': 5746.038132076163}
[14]:
res_ecv.iloc[info_ecv['best_params_index']]['risk_test-{}'.format(info_ecv['best_n_estimators'])]
[14]:
6489.090550260457

SplitCV#

[15]:
from sklearn_ensemble_cv import splitCV
res_splitcv, info_splitcv = splitCV(
        X_train, y_train, DecisionTreeRegressor, grid_regr, grid_ensemble,
        M=50, return_df=True, X_test=X_test, Y_test=y_test,
        random_state=0, test_size=0.25,
        )
[16]:
res_splitcv
[16]:
max_depth max_features max_samples risk_val-1 risk_val-2 risk_val-3 risk_val-4 risk_val-5 risk_val-6 risk_val-7 ... risk_test-41 risk_test-42 risk_test-43 risk_test-44 risk_test-45 risk_test-46 risk_test-47 risk_test-48 risk_test-49 risk_test-50
0 6 0.9 0.6 34236.829374 24572.859001 24089.331547 20126.642795 19197.691635 19186.518328 18221.490588 ... 20701.738963 20777.233024 20772.732219 20796.082041 20791.271087 20727.242272 20616.403023 20523.375255 20380.819539 20347.093017
1 6 0.9 0.7 34580.811141 25643.115466 23029.787721 19895.601540 17891.692078 17474.479137 17490.012418 ... 20223.272572 20324.397763 20276.894389 20290.823162 20318.546308 20256.874357 20170.205996 20082.758205 19945.267248 19917.513965
2 6 1.0 0.6 16965.823750 10651.257039 7820.469737 8549.417187 7626.003090 6258.777113 5719.689532 ... 7121.881773 7203.974379 7208.648403 7163.960004 7128.238064 7096.740403 7099.966528 7109.442178 7125.686639 7124.766814
3 6 1.0 0.7 16083.117767 10094.510250 7580.273496 7703.161182 7472.627416 5658.812631 5325.656624 ... 6768.441759 6818.322381 6813.679161 6806.056024 6793.418635 6802.534722 6782.559004 6778.832592 6771.568412 6771.441540
4 7 0.9 0.6 33426.578299 24227.602788 23547.269697 20069.020637 19675.583979 19362.002741 18841.820338 ... 20494.193166 20554.377651 20556.807509 20592.069985 20600.202449 20538.041432 20428.642319 20329.256952 20175.037066 20141.240174
5 7 0.9 0.7 34677.370184 26075.125134 23433.680469 19862.202238 19896.682442 19039.128286 18831.457891 ... 20046.193973 20113.558631 20067.368975 20091.244182 20117.912405 20071.883456 19968.212024 19868.072252 19713.281090 19679.029390
6 7 1.0 0.6 17818.328001 11228.031531 8340.476988 8545.357025 7532.603641 6290.461478 5751.880466 ... 6746.082290 6820.532745 6836.018747 6782.622375 6746.174808 6726.519939 6745.030775 6756.850707 6774.884711 6776.432287
7 7 1.0 0.7 16337.894141 10332.346436 7000.306648 7497.992465 7431.950940 5502.729419 5524.911924 ... 6482.885052 6533.648902 6532.246667 6513.838768 6498.485707 6497.445762 6483.066188 6481.863696 6480.960683 6479.876117

8 rows × 103 columns

[17]:
info_splitcv
[17]:
{'best_params_regr': {'max_depth': 6},
 'best_params_ensemble': {'random_state': 0,
  'n_estimators': 50,
  'max_features': 1.0,
  'max_samples': 0.7},
 'best_n_estimators': 45,
 'best_params_index': 3,
 'best_score': 4033.689260470907,
 'split_params': {'index_train': array([ 61,  92, 112,   2, 141,  43,  10,  60, 116, 144, 119, 108,  69,
         135,  56,  80, 123, 133, 106, 146,  50, 147,  85,  30, 101,  94,
          64,  89,  91, 125,  48,  13, 111,  95,  20,  15,  52,   3, 149,
          98,   6,  68, 109,  96,  12, 102, 120, 104, 128,  46,  11, 110,
         124,  41, 148,   1, 113, 139,  42,   4, 129,  17,  38,   5,  53,
         143, 105,   0,  34,  28,  55,  75,  35,  23,  74,  31, 118,  57,
         131,  65,  32, 138,  14, 122,  19,  29, 130,  49, 136,  99,  82,
          79, 115, 145,  72,  77,  25,  81, 140, 142,  39,  58,  88,  70,
          87,  36,  21,   9, 103,  67, 117,  47]),
  'index_val': array([114,  62,  33, 107,   7, 100,  40,  86,  76,  71, 134,  51,  73,
          54,  63,  37,  78,  90,  45,  16, 121,  66,  24,   8, 126,  22,
          44,  97,  93,  26, 137,  84,  27, 127, 132,  59,  18,  83]),
  'test_size': 0.25,
  'random_state': 0}}
[18]:
res_splitcv.iloc[info_splitcv['best_params_index']]['risk_test-{}'.format(info_splitcv['best_n_estimators'])]
[18]:
6793.418635

KFoldCV#

[19]:
from sklearn_ensemble_cv import KFoldCV
res_kfoldcv, info_kfoldcv = KFoldCV(
        X_train, y_train, DecisionTreeRegressor, grid_regr, grid_ensemble,
        M=50, return_df=True, X_test=X_test, Y_test=y_test,
        shuffle=True, random_state=0, n_splits=5,
        )
[20]:
res_kfoldcv
[20]:
max_depth max_features max_samples risk_val-1 risk_val-2 risk_val-3 risk_val-4 risk_val-5 risk_val-6 risk_val-7 ... risk_test-41 risk_test-42 risk_test-43 risk_test-44 risk_test-45 risk_test-46 risk_test-47 risk_test-48 risk_test-49 risk_test-50
0 6 0.9 0.6 36545.771967 27938.257395 24134.373249 22484.269302 22414.000485 21950.847904 21517.275118 ... 20716.631252 20736.461185 20741.668610 20771.880131 20764.847922 20724.517323 20632.514598 20582.426928 20509.866171 20488.539988
1 6 0.9 0.7 36177.830350 26909.814128 23494.918775 22117.966420 22352.988859 21574.710161 21334.398153 ... 20971.227017 21038.568944 21057.930222 21112.938287 21121.274454 21094.416399 20974.712606 20909.166909 20803.096906 20785.131017
2 6 1.0 0.6 18292.692037 12494.399173 10534.652350 9255.066365 9015.318876 8600.248786 8310.603633 ... 7488.364379 7465.797025 7438.536378 7406.623385 7395.788712 7369.921643 7363.239896 7343.266942 7323.823346 7306.897859
3 6 1.0 0.7 17504.625701 11988.718580 10045.440886 9310.529660 8943.994034 7881.303377 7796.536888 ... 7012.628902 7010.448402 7009.868242 6982.805845 6986.106436 6954.242473 6984.252017 6969.651320 6949.117880 6942.209406
4 7 0.9 0.6 37513.970568 29140.205014 25235.190952 22746.825360 22700.608988 22236.588719 21434.062802 ... 20968.167706 20964.003785 20965.059536 20986.504152 20976.639000 20920.110290 20822.718078 20756.068356 20662.774588 20631.524679
5 7 0.9 0.7 36535.534994 27236.065893 23964.159703 22014.460201 22375.658196 21824.093767 21351.289456 ... 20782.401137 20835.509493 20855.989334 20912.039238 20922.025454 20889.501766 20778.093847 20716.453342 20617.880339 20600.168542
6 7 1.0 0.6 18443.271257 12692.367011 10742.179078 9576.634688 9093.749373 8836.020076 8483.320394 ... 7589.567231 7566.403959 7537.413573 7499.722849 7483.501442 7445.006633 7439.014995 7414.823873 7390.486999 7370.135227
7 7 1.0 0.7 17611.199835 11924.784208 9890.939393 9477.430529 9118.334376 8030.870250 7867.956759 ... 7056.478013 7059.444686 7056.407264 7029.778004 7034.609903 6996.837311 7029.396458 7016.586506 6998.825844 6992.662602

8 rows × 103 columns

[21]:
info_kfoldcv
[21]:
{'best_params_regr': {'max_depth': 7},
 'best_params_ensemble': {'random_state': 0,
  'n_estimators': 50,
  'max_features': 1.0,
  'max_samples': 0.7},
 'best_n_estimators': 44,
 'best_params_index': 7,
 'best_score': 6459.598772647386,
 'val_score': array([[[35063.34458322, 35939.07984004, 36684.14739511, 33715.27011173,
          41327.01790531],
         [24675.57954741, 28029.25562499, 27811.81751007, 24437.13036382,
          34737.50392844],
         [21640.19294821, 28001.08434891, 22413.66188661, 20742.76431037,
          27874.16275094],
         ...,
         [16981.41315783, 18983.02155504, 18701.58282469, 18204.59547867,
          24964.4574873 ],
         [16819.75871394, 18905.14712851, 18713.80627924, 18266.68261893,
          25021.57181105],
         [16804.6149917 , 18920.61786553, 18700.8619768 , 18242.83754872,
          25033.04059112]],

        [[34440.82981543, 35056.31200754, 35505.50392758, 34808.60905139,
          41077.89694803],
         [24856.84534113, 26796.49263015, 25459.79869853, 25267.5821653 ,
          32168.35180429],
         [21970.91883632, 28632.83622643, 20011.175936  , 22904.66200669,
          23955.00086764],
         ...,
         [16469.91706412, 17942.85842715, 16928.76174278, 18536.79035297,
          23559.56438967],
         [16314.7960805 , 17827.03033559, 16846.70733556, 18625.29051723,
          23560.8491443 ],
         [16230.31923536, 17848.60512394, 16862.71966614, 18627.56571855,
          23600.0783045 ]],

        [[14698.7607123 , 19341.20118293, 18663.691853  , 19608.60182531,
          19151.20461318],
         [ 9288.57093611, 13245.04788852, 12769.69250122, 14245.29995461,
          12923.3845862 ],
         [ 6967.89128802, 11779.06474754,  8210.58320761, 12967.81347843,
          12747.90902751],
         ...,
         [ 3999.5843062 ,  7933.3185448 ,  6475.28443968,  8695.61301957,
           7073.92536648],
         [ 3978.92328013,  7941.30871   ,  6383.15873455,  8696.82346999,
           7083.61387365],
         [ 3971.84034026,  7937.66491072,  6353.00940483,  8658.51429744,
           7105.63634097]],

        ...,

        [[35279.50745908, 35809.10824398, 35949.92883623, 35053.19190269,
          40585.93852865],
         [25669.38534688, 27949.56680589, 25460.08051813, 25648.34119935,
          31452.95559431],
         [23497.4919398 , 28866.55344735, 20151.61426576, 22751.33203451,
          24553.80682678],
         ...,
         [17245.0773532 , 18366.71717206, 17272.26108278, 18879.89081661,
          23616.80098205],
         [17063.60544168, 18292.95972297, 17223.38792004, 18922.1743427 ,
          23654.24178593],
         [16969.53612248, 18327.86059386, 17243.97390408, 18934.74314066,
          23685.57133032]],

        [[14631.53025636, 19715.86684263, 18369.05735446, 20063.44447163,
          19436.457361  ],
         [ 9372.79909338, 13522.47292053, 12386.12328312, 14596.47523879,
          13583.96451801],
         [ 7396.27619918, 11851.73272447,  8132.36053753, 12677.29413055,
          13653.23180052],
         ...,
         [ 4029.26854999,  7919.73712512,  6352.287586  ,  8777.54026565,
           7569.16528422],
         [ 4013.72861198,  7896.63151974,  6274.04071139,  8791.90490089,
           7588.76017146],
         [ 4008.00139913,  7876.31008759,  6240.13676904,  8772.70606893,
           7595.14144891]],

        [[14546.90348333, 18131.09331047, 18628.94328335, 18461.54205193,
          18287.51704787],
         [ 9444.69483056, 12280.33720004, 12839.71992806, 12857.33919818,
          12201.82988316],
         [ 5791.37815957, 10873.95098952, 10854.18612632, 10706.06587174,
          11229.11581557],
         ...,
         [ 4126.46430885,  6665.5932463 ,  6649.71318028,  7554.78217729,
           7638.66922866],
         [ 4147.63391126,  6667.00993029,  6670.14575688,  7607.66497255,
           7682.7515881 ],
         [ 4143.60272463,  6660.87045332,  6690.95061268,  7607.74483093,
           7703.03270678]]]),
 'test_score': array([[[39165.3323162 , 37266.99911136, 39418.30497579, 37515.62699599,
          37466.07541638],
         [30303.71133516, 27951.86146046, 30535.66855725, 28824.88560828,
          29765.7570254 ],
         [25132.58174696, 24640.22250014, 25796.87625411, 25118.50330304,
          26836.21668138],
         ...,
         [20702.89124779, 19597.02847276, 21076.24189243, 20555.68676525,
          20980.28626208],
         [20664.37351521, 19454.9286556 , 20979.4963262 , 20515.62458776,
          20934.90777093],
         [20647.01925582, 19426.8299882 , 20950.47374189, 20510.71859035,
          20907.65836599]],

        [[39662.2040839 , 38064.906745  , 39464.67003986, 37934.15713478,
          37712.09368195],
         [30890.81314418, 29301.12018517, 30621.514568  , 29017.21640016,
          29394.76136548],
         [27944.99800406, 26181.83960348, 27013.59560781, 24762.55218561,
          25363.95540044],
         ...,
         [20705.33509318, 20672.73504786, 21235.58724364, 20631.87760921,
          21300.29955082],
         [20679.58573704, 20511.09458874, 21101.5154004 , 20561.72074741,
          21161.5680556 ],
         [20669.64282724, 20464.19064318, 21078.18116443, 20570.65517223,
          21142.98527777]],

        [[18346.14510059, 17879.89409   , 19674.81513074, 18103.21744695,
          20420.96515376],
         [13037.31663217, 12443.45905168, 13728.15801328, 12978.34486268,
          13972.02330848],
         [11236.82311459, 10592.61004946, 11567.77239084, 10944.68413877,
          12863.05250065],
         ...,
         [ 7034.60443681,  7028.0415416 ,  7505.14988053,  7640.19417952,
           7508.34467066],
         [ 7041.56157935,  7048.19283455,  7465.63284616,  7614.6578218 ,
           7449.07165022],
         [ 7017.70818603,  7038.95800331,  7459.92295706,  7578.08603415,
           7439.8141151 ]],

        ...,

        [[40051.18485523, 37426.36255274, 39565.72017728, 37919.82863512,
          37441.51890511],
         [30941.35103122, 28196.71984235, 30536.48126264, 29048.63836381,
          29496.46098384],
         [28284.67021057, 24954.91035021, 26559.92118039, 24187.57902697,
          25485.33520472],
         ...,
         [20733.25001227, 20060.70134276, 21148.2712261 , 20453.76720747,
          21186.27691938],
         [20704.72105743, 19928.59256627, 20979.32575228, 20408.90551765,
          21067.85680264],
         [20696.68237184, 19896.74715288, 20944.25370244, 20416.18904578,
          21046.97043471]],

        [[18441.05565727, 18053.14017694, 19565.86681603, 18486.53967116,
          20028.1069874 ],
         [12903.63637131, 12617.60682799, 13497.68418668, 13373.4350388 ,
          13628.32180987],
         [11449.10205102, 10618.03584122, 11492.20900414, 11342.586786  ,
          12775.3466727 ],
         ...,
         [ 6882.25526743,  7260.42696962,  7496.62272947,  7862.07617409,
           7572.73822419],
         [ 6880.2566495 ,  7269.92848736,  7430.09224123,  7853.52588492,
           7518.631732  ],
         [ 6855.41795554,  7257.6635736 ,  7419.19615941,  7811.33837634,
           7507.06007013]],

        [[18146.01266838, 17162.05374621, 18841.37500506, 17245.71676602,
          18817.78862521],
         [12968.57558432, 11806.3614859 , 12929.74216542, 11979.4479075 ,
          12854.06854124],
         [11036.70482972,  9878.84622249, 10847.20550266,  9494.79869512,
          11585.03628959],
         ...,
         [ 6814.19036431,  6715.77467929,  7228.70196355,  7145.33550997,
           7178.93001472],
         [ 6817.11369554,  6675.06619066,  7210.78776298,  7158.36602222,
           7132.79554854],
         [ 6809.05570577,  6670.97662815,  7212.36479296,  7138.75929562,
           7132.15658735]]]),
 'split_params': {'n_splits': 5, 'random_state': 0, 'shuffle': True}}
[22]:
res_kfoldcv.iloc[info_kfoldcv['best_params_index']]['risk_test-{}'.format(info_kfoldcv['best_n_estimators'])]
[22]:
7029.77800406061

GCV and CGCV#

GCV estimates#

Generalized cross-validation (GCV) is a method for estimating the prediction error of an estimator. It is commonly used for linear smoothers such as ridge regression, and it is also applied to the lasso and elastic net. It is a form of cross-validation without data splitting: the estimated risk is the training error divided by \((1 - \mathrm{df}/n)^2\), where \(\mathrm{df}\) is the effective degrees of freedom of the estimator and \(n\) is the number of observations.

There are two variants of the naive GCV for ensemble learning:

  • ‘full’: using all observations to estimate the risk

    \[\tilde{R}_M^{gcv} = \frac{\|{{y}}-{{X}}\tilde{\beta}_{M}\|_2^2 / n }{(1 - \tilde{df}_M / n )^2},\]

    which is consistent when \(k=n\) or \(M=\infty\), where \(k\) denotes the subsample size and \(n\) the number of training observations.

  • ‘union’: using the union of training observations to estimate the risk

    \[\tilde{R}_M^{gcv} = \frac{\|{ L}_{I_{1:M}} ({{y}}-{{X}}\tilde{\beta}_{M})\|_2^2 / |I_{1:M}| }{(1 - \tilde{df}_M / |I_{1:M}| )^2},\]

    which is consistent when \(k=n\) or \(M\in\{1,\infty\}\).
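
For a single ridge estimator fitted on all observations (\(M=1\), \(k=n\)), the ‘full’ formula reduces to the classical GCV. The sketch below evaluates it with NumPy on toy data, taking the degrees of freedom as the trace of the ridge smoother matrix. All names and the simulated data are illustrative assumptions, and the ensemble version in compute_gcv_estimate is not reproduced here.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, lam = 200, 50, 5.0
X_gcv = rng.normal(size=(n, p))
y_gcv = X_gcv @ rng.normal(size=p) + rng.normal(size=n)

# Ridge solution for the objective ||y - X b||^2 + lam * ||b||^2.
G = np.linalg.solve(X_gcv.T @ X_gcv + lam * np.eye(p), X_gcv.T)  # shape (p, n)
beta_hat = G @ y_gcv

# Degrees of freedom: trace of the smoother matrix X (X'X + lam I)^{-1} X'.
df = np.trace(X_gcv @ G)

# GCV inflates the training error by 1 / (1 - df / n)^2.
train_err = np.mean((y_gcv - X_gcv @ beta_hat) ** 2)
gcv = train_err / (1.0 - df / n) ** 2
print(train_err, df, gcv)
```

Because \(0 < \mathrm{df} < p < n\) in this setting, the denominator is strictly between 0 and 1, so the GCV value is always larger than the training error.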

[23]:
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn_ensemble_cv import generate_data

reset_random_seeds(0)
n_samples, n_features = 1000, 800
Sigma, beta0, X_train, y_train, X_test, y_test, _, _ = generate_data(
    n_samples, n_features, coef='random', func='quad', sigma_quad=.1,
    rho_ar1=0., sigma=.5, df=np.inf, n_test=1000,
)

kwargs_regr = {'alpha': 0.01*X_train.shape[0], 'fit_intercept':False}
kwargs_ensemble = {'max_samples': 0.5, 'bootstrap':False}
regr = Ensemble(estimator=Ridge(**kwargs_regr), n_estimators=100, **kwargs_ensemble)
regr = regr.fit(X_train, y_train)
[24]:
df_gcv = regr.compute_gcv_estimate(X_train, y_train, M0=25, type='union', return_df=True)
[25]:
df_risk = regr.compute_risk(X_test, y_test, return_df=True)

CGCV estimates#

The corrected GCV (CGCV) estimators are shown to be consistent for arbitrary ensemble sizes:

\[\tilde{R}^{cgcv,\#}_{M} = \underbrace{\frac{\|{{y}}-{{X}}\tilde{\beta}_{M}\|_2^2 / n }{(1 - \tilde{df}_M / n )^2} }_{\tilde{R}_M^{gcv}} - \underbrace{\frac{1}{M} \Biggl\{ \frac{(\tilde{df}_M/n)^2}{(1-\tilde{df}_M/n)^2} \frac{1}{M}\sum_{m \in [M]} \bigg(\frac{n}{|I_m|}-1\bigg) \hat{R}_{m, m}^{\#} \Biggr\}}_{\mathrm{correction}}.\]

There are two options for the risk component estimators:

  • \(\#=\)‘full’: using full observations to estimate \(R_{m,m}\).

  • \(\#=\)‘ovlp’: using overlapping observations to estimate \(R_{m,m}\).

Intuitively, the former should be more robust as more observations are used.
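
The correction term is plain arithmetic once its ingredients are available. The sketch below applies the displayed formula to made-up inputs: a naive GCV value, the degrees of freedom \(\tilde{df}_M\), the subsample sizes \(|I_m|\), and the component risk estimates \(\hat{R}_{m,m}^{\#}\). The function cgcv_from_components is a hypothetical helper and the numbers are not taken from any fitted model.

```python
import numpy as np

def cgcv_from_components(gcv, df, n, subsample_sizes, R_mm):
    """Subtract the 1/M correction term of the displayed formula from a naive GCV value.

    gcv: naive GCV estimate; df: degrees of freedom of the ensemble;
    n: number of observations; subsample_sizes: |I_m| for each estimator;
    R_mm: component risk estimates, one per estimator.
    """
    subsample_sizes = np.asarray(subsample_sizes, dtype=float)
    R_mm = np.asarray(R_mm, dtype=float)
    M = len(subsample_sizes)
    ratio = (df / n) ** 2 / (1.0 - df / n) ** 2
    correction = ratio * np.mean((n / subsample_sizes - 1.0) * R_mm) / M
    return gcv - correction

# Illustrative inputs: n = 100, df = 50, two estimators on subsamples of size 50.
cgcv_demo = cgcv_from_components(2.0, 50.0, 100, [50, 50], [1.0, 1.0])
print(cgcv_demo)  # 2.0 - 0.5 = 1.5
```

Two features of the formula are visible here: the correction vanishes when every estimator uses all observations (\(|I_m| = n\)), and it shrinks at rate \(1/M\) as the ensemble grows.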

[26]:
df_cgcv = regr.compute_cgcv_estimate(X_train, y_train, M0 = 25, type='full', return_df=True)
[27]:
plt.plot(df_gcv['M'], df_gcv['estimate'], label='GCV estimate')
plt.plot(df_cgcv['M'], df_cgcv['estimate'], label='CGCV estimate')
plt.plot(df_risk['M'], df_risk['risk'], label='risk')
plt.legend()
plt.show()
../_images/tutorials_basics_45_0.png

GCV and CGCV for tuning#

Below we apply CGCV to tune the ensemble parameters for ridge regression. The function GCV also supports naive GCV estimates (when corrected=False), but these are inconsistent for the prediction risk and are included only for completeness. We therefore do not recommend using naive GCV for tuning.

Both GCV and CGCV were proposed for subagging (bootstrap=False in kwargs_ensemble); however, numerically they also provide reasonable estimates for bagging, as we illustrate below.
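
The difference between the two sampling schemes can be seen directly on index sets. The sketch below draws one subsample without replacement (subagging) and one with replacement (bagging) using NumPy. It only illustrates the sampling step and assumes nothing about how the ensemble classes draw their indices internally.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, k = 1000, 500

# Subagging: k distinct indices, drawn without replacement.
idx_subagging = rng.choice(n_obs, size=k, replace=False)
# Bagging: k indices drawn with replacement, so duplicates can occur.
idx_bagging = rng.choice(n_obs, size=k, replace=True)

print(len(np.unique(idx_subagging)))  # always k
print(len(np.unique(idx_bagging)))    # at most k, typically smaller
```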

[28]:
from sklearn_ensemble_cv import GCV

grid_regr = {'alpha': np.array([0.01,0.05])*X_train.shape[0]}
grid_ensemble = {'max_samples': [0.6,0.8]}

res_cgcv, info_cgcv = GCV(
        X_train, y_train, Ridge, grid_regr, grid_ensemble,
        M=100, M0=25, corrected=True, type='full', return_df=True, X_test=X_test, Y_test=y_test,
        )
[29]:
res_cgcv
[29]:
alpha max_samples risk_val-1 risk_val-2 risk_val-3 risk_val-4 risk_val-5 risk_val-6 risk_val-7 risk_val-8 ... risk_test-91 risk_test-92 risk_test-93 risk_test-94 risk_test-95 risk_test-96 risk_test-97 risk_test-98 risk_test-99 risk_test-100
0 10.0 0.6 1.270438 0.942270 0.825168 0.763242 0.731527 0.708762 0.693972 0.688929 ... 0.632743 0.632524 0.632333 0.632490 0.632899 0.633189 0.633581 0.633870 0.634254 0.634421
1 10.0 0.8 1.689871 1.135225 0.944082 0.853987 0.792728 0.749514 0.716791 0.705366 ... 0.572260 0.572334 0.572093 0.571998 0.572339 0.572438 0.572891 0.573057 0.573264 0.573372
2 50.0 0.6 1.134798 0.886978 0.791813 0.744230 0.720062 0.706429 0.693548 0.691763 ... 0.659251 0.659057 0.658901 0.659089 0.659451 0.659735 0.660065 0.660349 0.660731 0.660897
3 50.0 0.8 1.296153 0.943016 0.811787 0.753779 0.714835 0.690649 0.670342 0.666039 ... 0.593066 0.593141 0.592986 0.592990 0.593272 0.593415 0.593756 0.593978 0.594269 0.594402

4 rows × 203 columns

[30]:
info_cgcv
[30]:
{'best_params_regr': {'alpha': 10.0},
 'best_params_ensemble': {'random_state': 0,
  'n_estimators': 100,
  'max_samples': 0.8},
 'best_n_estimators': 100,
 'best_params_index': 1,
 'best_score': 0.583519102039084}
[31]:
res_cgcv.iloc[info_cgcv['best_params_index']]['risk_test-{}'.format(info_cgcv['best_n_estimators'])]
[31]:
0.5733721598232984