Ensemble Cross-validation for Multitask Regression#
[ ]:
import sys
sys.path.append('..')
import os
n_cores = 8
os.environ["OMP_NUM_THREADS"] = f"{n_cores}"
os.environ["OPENBLAS_NUM_THREADS"] = f"{n_cores}"
os.environ["MKL_NUM_THREADS"] = f"{n_cores}"
os.environ["VECLIB_MAXIMUM_THREADS"] = f"{n_cores}"
os.environ["NUMEXPR_NUM_THREADS"] = f"{n_cores}"
os.environ["NUMBA_CACHE_DIR"]='/tmp/numba_cache'
import numpy as np
from sklearn_ensemble_cv import reset_random_seeds
import matplotlib.pyplot as plt
reset_random_seeds(0)
We generate synthetic data for illustration. The response has two dimensions (n_targets=2).
[2]:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
X, y = make_regression(n_samples=300, n_features=200,
                       n_informative=5, n_targets=2,
                       random_state=0, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
The Ensemble class#
For users who want more control over the ensemble predictor, this section introduces the lower-level class that we use to implement the cross-validation methods. Users who only need the high-level interface functions can safely skip this section.
We provide Ensemble, a class for ensemble predictors whose base class is sklearn.ensemble.BaggingRegressor. Its usage is therefore essentially the same as that of BaggingRegressor, except that it adds several member functions, which we illustrate below.
Initialize an object#
An Ensemble object is initialized in the same way as sklearn.ensemble.BaggingRegressor, by passing:
- The base estimator object, whose hyperparameters (kwargs_regr) are specified when it is initialized. In the following example, we use a decision tree as the base estimator.
- The hyperparameters for building the ensemble, such as n_estimators, max_samples, and max_features.
[3]:
from sklearn.tree import DecisionTreeRegressor
from sklearn_ensemble_cv import Ensemble
kwargs_regr = {'max_depth': 7}
kwargs_ensemble = {'max_samples': 0.8}
regr = Ensemble(estimator=DecisionTreeRegressor(**kwargs_regr), n_estimators=100, **kwargs_ensemble)
The prediction is of the same dimension as the response.
[4]:
regr.fit(X_train, y_train)
Y_test_hat = regr.predict(X_test)
Y_test_hat.shape
[4]:
(150, 2)
We can also obtain the predictions of the individual base estimators. The last axis of the result indexes the 100 estimators.
[5]:
Y_train_hat = regr.predict_individual(X_train)
Y_train_hat.shape
[5]:
(150, 2, 100)
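Because the trailing axis indexes the base estimators, averaging over it should recover the ensemble prediction, and a running average over it gives the prediction of every sub-ensemble of size \(M\). The sketch below illustrates this with a random stand-in array of the same shape. The helper name `subensemble_risk` and the choice of mean squared error averaged over both targets are assumptions for illustration, not the library's definitions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, q, M = 150, 2, 100
# Stand-in for the output of predict_individual: (samples, targets, estimators)
preds = rng.normal(size=(n, q, M))
y_true = rng.normal(size=(n, q))

def subensemble_risk(preds, y_true):
    """Squared error of the average of the first m estimators, m = 1..M (hypothetical helper)."""
    m = np.arange(1, preds.shape[-1] + 1)
    running_mean = np.cumsum(preds, axis=-1) / m        # (n, q, M)
    sq_err = (running_mean - y_true[..., None]) ** 2    # (n, q, M)
    return sq_err.mean(axis=(0, 1))                     # (M,)

risk = subensemble_risk(preds, y_true)
full_ensemble = preds.mean(axis=-1)                     # (n, q)
print(risk.shape, full_ensemble.shape)
```

The last entry of `risk` is the squared error of the full ensemble, and the first entry is that of the first base estimator alone.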
Compute ECV estimate of the prediction risk#
Below, we use the compute_ecv_estimate method to estimate the prediction risk for ensemble sizes \(M=1,\ldots,100\), using only the first \(M_0=30\) trees.
[6]:
# compute the risk estimate of the ensemble
df_est = regr.compute_ecv_estimate(X_train, y_train, M0=30, return_df=True)
# compute the risk of the ensemble on the test set
df_risk = regr.compute_risk(X_test, y_test, return_df=True)
[7]:
plt.plot(df_est['M'], df_est['estimate'], label='estimate')
plt.plot(df_risk['M'], df_risk['risk'], label='risk')
plt.legend()
plt.show()
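Estimating the risk at larger \(M\) from only a few trees is possible because, under squared error, the risk of an average of \(M\) exchangeable base predictors satisfies \(R_M = R_\infty + (R_1 - R_\infty)/M\) with \(R_\infty = 2R_2 - R_1\). The sketch below checks this identity numerically on simulated predictors. It illustrates the idea only: the helper `extrapolate_risk` and the simulated data are assumptions for illustration, and the way compute_ecv_estimate obtains its estimates from the data is not reproduced.

```python
import numpy as np

def extrapolate_risk(r1, r2, M):
    """Risk of an M-ensemble implied by the risks of 1- and 2-ensembles (hypothetical helper)."""
    M = np.asarray(M, dtype=float)
    r_inf = 2.0 * r2 - r1                  # limiting risk as M grows
    return r_inf + (r1 - r_inf) / M

rng = np.random.default_rng(0)
n, B = 500, 30
signal = rng.normal(size=n)
y = signal + rng.normal(scale=0.5, size=n)
# Simulated exchangeable base predictors: shared signal plus independent noise
preds = signal[:, None] + rng.normal(size=(n, B))

err = preds - y[:, None]                   # residuals, shape (n, B)
gram = err.T @ err / n                     # gram[i, j] = mean product of residuals i and j
r1 = np.trace(gram) / B                    # mean risk over single predictors
cross = (gram.sum() - np.trace(gram)) / (B * (B - 1))
r2 = 0.5 * (r1 + cross)                    # mean risk over all pairs of predictors

risk_curve = extrapolate_risk(r1, r2, np.arange(1, 101))
risk_full = np.mean(err.mean(axis=1) ** 2) # risk of the average of all B predictors
print(np.isclose(risk_curve[B - 1], risk_full))
```

Two points on the curve, \(R_1\) and \(R_2\), determine the whole curve, which is why the extrapolated value at \(M = B\) agrees with the directly computed risk of the \(B\)-ensemble.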
ECV for tuning hyperparameters#
To tune hyperparameters, such as max_samples and max_features for the ensemble and max_depth and min_samples_leaf for the base predictors, we define two grids, one for each type of hyperparameter.
We recommend using an np.array for each parameter, with a dtype that sklearn accepts. To fix a hyperparameter at a single value, provide it in the grid as either a scalar or a list/array of length one.
[8]:
from sklearn_ensemble_cv import ECV
# Hyperparameters for the base regressor
grid_regr = {
    'max_depth': np.array([6, 7], dtype=int),
}
# Hyperparameters for the ensemble
grid_ensemble = {
    'max_features': np.array([0.9, 1.]),
    'max_samples': np.array([0.6, 0.7]),
}
res_ecv, info_ecv = ECV(
    X_train, y_train, DecisionTreeRegressor, grid_regr, grid_ensemble,
    X_test=X_test, Y_test=y_test,
    M=50, M0=25, return_df=True
)
[9]:
info_ecv
[9]:
{'best_params_regr': {'max_depth': 6},
'best_params_ensemble': {'random_state': 0,
'n_estimators': 50,
'max_features': 1.0,
'max_samples': 0.7},
'best_n_estimators': 50,
'best_params_index': 3,
'best_score': 6118.919276518704,
'delta': 0.0,
'M_max': inf,
'best_n_estimators_extrapolate': inf,
'best_score_extrapolate': 5905.261207161817}
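The value of best_params_index can be related to the two grids by enumerating their Cartesian product. The sketch below assumes the base-regressor keys come first and the last key varies fastest. This ordering is inferred from the outputs shown on this page, not taken from the library source, and the variable names are for illustration only.

```python
from itertools import product

grid_regr = {'max_depth': [6, 7]}
grid_ensemble = {'max_features': [0.9, 1.0], 'max_samples': [0.6, 0.7]}

# Assumed ordering: base-regressor keys first, last key varying fastest
names = list(grid_regr) + list(grid_ensemble)
values = list(grid_regr.values()) + list(grid_ensemble.values())
configs = [dict(zip(names, combo)) for combo in product(*values)]

print(len(configs))   # 8
print(configs[3])     # {'max_depth': 6, 'max_features': 1.0, 'max_samples': 0.7}
```

Under this ordering, index 3 matches the best parameters reported by ECV above, and indices 2 and 7 match those reported by splitCV and KFoldCV later on this page.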
[10]:
res_ecv.iloc[info_ecv['best_params_index']]
[10]:
max_depth 6.000000
max_features 1.000000
max_samples 0.700000
risk_val-1 16588.164675
risk_val-2 11246.712941
...
risk_test-46 5698.734060
risk_test-47 5702.237709
risk_test-48 5709.943647
risk_test-49 5712.515747
risk_test-50 5722.970560
Name: 3, Length: 104, dtype: float64
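The values risk_val-1 and risk_val-2 in the row above are consistent with the best_score and best_score_extrapolate reported in info_ecv under the two-point extrapolation \(R_M = R_\infty + (R_1 - R_\infty)/M\), \(R_\infty = 2R_2 - R_1\). The arithmetic below checks this using the printed (rounded) values. The helper names are hypothetical, and reading delta as a tolerance on the gap \(R_M - R_\infty\) is an assumption; it is at least consistent with delta = 0 yielding an infinite best_n_estimators_extrapolate.

```python
import math

# Values copied from the output above (rounded as printed)
r1 = 16588.164675   # risk_val-1
r2 = 11246.712941   # risk_val-2

def extrapolated_risk(r1, r2, M):
    """R_M implied by R_1 and R_2 (hypothetical helper)."""
    r_inf = 2.0 * r2 - r1
    return r_inf if math.isinf(M) else r_inf + (r1 - r_inf) / M

def smallest_size_within(r1, r2, delta, M_max=math.inf):
    """Smallest M with R_M - R_inf <= delta (assumed meaning of delta)."""
    gap = 2.0 * (r1 - r2)            # equals R_1 - R_inf
    if gap <= 0:
        return 1
    if delta <= 0:
        return M_max
    return min(math.ceil(gap / delta), M_max)

print(round(extrapolated_risk(r1, r2, math.inf), 3))  # 5905.261
print(round(extrapolated_risk(r1, r2, 50), 3))        # 6118.919
print(smallest_size_within(r1, r2, delta=0.0))        # inf
print(smallest_size_within(r1, r2, delta=100.0))      # 107
```

The first two printed values agree with best_score_extrapolate and best_score to the precision shown in the output.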
SplitCV#
[11]:
from sklearn_ensemble_cv import splitCV
res_splitcv, info_splitcv = splitCV(
    X_train, y_train, DecisionTreeRegressor, grid_regr, grid_ensemble,
    M=50, return_df=True, X_test=X_test, Y_test=y_test,
    random_state=0, test_size=0.25,
)
[12]:
info_splitcv
[12]:
{'best_params_regr': {'max_depth': 6},
'best_params_ensemble': {'random_state': 0,
'n_estimators': 50,
'max_features': 1.0,
'max_samples': 0.6},
'best_n_estimators': 34,
'best_params_index': 2,
'best_score': 5240.007705341309,
'split_params': {'index_train': array([ 61, 92, 112, 2, 141, 43, 10, 60, 116, 144, 119, 108, 69,
135, 56, 80, 123, 133, 106, 146, 50, 147, 85, 30, 101, 94,
64, 89, 91, 125, 48, 13, 111, 95, 20, 15, 52, 3, 149,
98, 6, 68, 109, 96, 12, 102, 120, 104, 128, 46, 11, 110,
124, 41, 148, 1, 113, 139, 42, 4, 129, 17, 38, 5, 53,
143, 105, 0, 34, 28, 55, 75, 35, 23, 74, 31, 118, 57,
131, 65, 32, 138, 14, 122, 19, 29, 130, 49, 136, 99, 82,
79, 115, 145, 72, 77, 25, 81, 140, 142, 39, 58, 88, 70,
87, 36, 21, 9, 103, 67, 117, 47]),
'index_val': array([114, 62, 33, 107, 7, 100, 40, 86, 76, 71, 134, 51, 73,
54, 63, 37, 78, 90, 45, 16, 121, 66, 24, 8, 126, 22,
44, 97, 93, 26, 137, 84, 27, 127, 132, 59, 18, 83]),
'test_size': 0.25,
'random_state': 0}}
[13]:
res_splitcv.iloc[info_splitcv['best_params_index']]
[13]:
max_depth 6.000000
max_features 1.000000
max_samples 0.600000
risk_val-1 19024.169727
risk_val-2 12610.242080
...
risk_test-46 6833.434198
risk_test-47 6828.288653
risk_test-48 6805.733968
risk_test-49 6763.855064
risk_test-50 6762.709618
Name: 2, Length: 103, dtype: float64
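The arguments test_size=0.25 and random_state=0, together with the index_train (112 entries) and index_val (38 entries) arrays in split_params, indicate that a single train/validation split is drawn and the validation risk is tracked for each ensemble size. A minimal sketch of that idea with plain scikit-learn follows. It uses a smaller synthetic dataset, one fixed configuration, BaggingRegressor in place of Ensemble, and mean squared error averaged over both targets, so it is an illustration under those assumptions rather than the package's implementation.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import BaggingRegressor
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=150, n_features=20, n_informative=5,
                       n_targets=2, random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

M = 20
bag = BaggingRegressor(DecisionTreeRegressor(max_depth=6), n_estimators=M,
                       max_samples=0.6, random_state=0).fit(X_tr, y_tr)

# Predictions of each base estimator on the validation set: (n_val, targets, M)
preds = np.stack([est.predict(X_val[:, feats])
                  for est, feats in zip(bag.estimators_, bag.estimators_features_)],
                 axis=-1)
running_mean = np.cumsum(preds, axis=-1) / np.arange(1, M + 1)
val_risk = ((running_mean - y_val[:, :, None]) ** 2).mean(axis=(0, 1))  # (M,)

best_M = int(val_risk.argmin()) + 1   # ensemble size with the smallest validation risk
print(X_tr.shape[0], X_val.shape[0], best_M)
```

Choosing the ensemble size by the minimum of the validation curve would explain why best_n_estimators can be smaller than M, as in the output above.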
KFoldCV#
[14]:
from sklearn_ensemble_cv import KFoldCV
res_kfoldcv, info_kfoldcv = KFoldCV(
    X_train, y_train, DecisionTreeRegressor, grid_regr, grid_ensemble,
    M=50, return_df=True, X_test=X_test, Y_test=y_test,
    shuffle=True, random_state=0, n_splits=5,
)
[15]:
info_kfoldcv
[15]:
{'best_params_regr': {'max_depth': 7},
'best_params_ensemble': {'random_state': 0,
'n_estimators': 50,
'max_features': 1.0,
'max_samples': 0.7},
'best_n_estimators': 32,
'best_params_index': 7,
'best_score': 6961.841676288075,
'val_score': array([[[23737.87238121, 30132.91179502, 30843.47362515, 28137.92795717,
33399.49306834],
[16257.2078051 , 23794.78292394, 23212.03247225, 21195.09262109,
26952.11363943],
[15443.15947895, 23989.58436141, 21341.04339101, 18529.64918437,
22063.59523662],
...,
[ 8022.52682882, 15997.12131698, 16240.36668888, 14929.25864793,
20074.35581705],
[ 7945.93509948, 15893.45991707, 16158.51048811, 14794.6479552 ,
19989.06648138],
[ 7915.96842979, 15900.1227764 , 16110.65661389, 14796.65783126,
19946.98247012]],
[[24481.34001435, 29673.33521269, 32222.32927357, 27241.45052042,
33134.92006409],
[15711.7815384 , 23216.12546456, 25048.72983354, 21248.74081758,
27209.41439006],
[12635.20123127, 21057.44711756, 21209.48679197, 18787.03171332,
23918.13091356],
...,
[ 7974.93810372, 15453.93285396, 16797.48048029, 14001.66255445,
19674.82960981],
[ 7946.11381175, 15476.37821402, 16674.0622092 , 13940.02449645,
19755.04830542],
[ 7933.53306295, 15521.22240165, 16656.04285777, 13922.96665038,
19758.23704248]],
[[17834.90038988, 18484.86969859, 19725.89938411, 17723.70847693,
22889.15703912],
[11033.04816503, 13442.83009499, 12234.03383196, 12390.75227706,
17467.73887621],
[ 9662.13757626, 12415.0640164 , 9768.83938898, 12364.94933621,
13849.84179978],
...,
[ 4948.15723296, 8208.90150095, 7066.70653642, 7240.17268584,
11049.56331849],
[ 5038.41715699, 8226.58057539, 7082.04588906, 7277.07266303,
11117.36763166],
[ 5027.64783628, 8228.56752723, 7112.59920133, 7281.97797639,
11111.85471114]],
...,
[[24436.96554084, 30482.38199722, 32130.86497289, 27260.84561294,
33250.30190407],
[16313.30409599, 23262.27348057, 24753.61249563, 21272.99472502,
27534.57301314],
[13147.04942379, 21028.55309838, 21526.9248147 , 19052.85808716,
23460.7907502 ],
...,
[ 8341.27316296, 16673.62047056, 16786.55635513, 14129.68198179,
20500.61254862],
[ 8285.91260541, 16658.7187558 , 16682.82164781, 14053.26412991,
20555.74103123],
[ 8258.95011231, 16699.11411562, 16657.23060246, 14042.44075317,
20544.41844142]],
[[17785.79070148, 18423.3444308 , 19801.30884607, 17863.589176 ,
23356.98476974],
[10488.14306894, 13432.64737473, 12443.06071001, 12018.45321479,
17721.79386454],
[ 9091.7984133 , 12415.5270367 , 9680.46657946, 11226.90356785,
15078.14700844],
...,
[ 4803.68593927, 8166.66358048, 7160.54451287, 6998.1261387 ,
11049.5055115 ],
[ 4912.21603099, 8199.13855482, 7167.90876605, 7054.51899049,
11123.79452391],
[ 4902.18466618, 8194.20041064, 7190.81692836, 7063.96557469,
11130.81310689]],
[[16612.10337243, 16844.92151582, 18563.80786054, 17101.40349172,
21061.39086277],
[10612.16538355, 12338.29372206, 12181.21510555, 11644.65965874,
15471.99341264],
[ 9449.66128944, 11157.58843759, 9840.97284127, 9470.77148106,
12998.74467869],
...,
[ 4609.57965452, 7404.22445543, 6900.88721173, 6127.88179131,
10454.4383407 ],
[ 4669.037748 , 7392.75622217, 6873.68592277, 6149.26781745,
10560.5583152 ],
[ 4682.71350579, 7406.42154103, 6869.88905811, 6159.69357508,
10580.64698805]]]),
'test_score': array([[[31639.4110734 , 30979.994316 , 31611.49340033, 30102.25971793,
29406.38375854],
[23915.56108414, 23559.81373648, 24134.94311323, 23475.25133059,
22866.17622594],
[21395.51921547, 21504.6525317 , 22609.23785716, 21209.18709017,
19337.70894863],
...,
[16437.15801911, 16489.23498681, 16693.90115661, 16496.31276469,
16375.9855555 ],
[16340.20462957, 16326.06225061, 16553.26505849, 16437.37871248,
16277.95012799],
[16299.00247387, 16282.11939158, 16504.58760339, 16418.67591469,
16231.83312089]],
[[32015.6621197 , 31054.54450487, 31738.46623086, 30729.93438674,
29726.19220699],
[24416.23769988, 23613.74257736, 24370.93946175, 23963.25289081,
23711.08676758],
[22250.00914775, 20676.61804408, 21470.47730792, 21954.25743257,
20987.84578695],
...,
[16524.2231838 , 16502.49853147, 16518.94814255, 16879.10502119,
16710.61147316],
[16428.85696598, 16390.97559518, 16432.8150019 , 16796.65680258,
16673.68685248],
[16407.97318099, 16357.39037845, 16391.59109261, 16778.03188568,
16673.57713964]],
[[18722.94712641, 17213.90217426, 18847.21173395, 17788.90260328,
19662.98677036],
[12770.77980235, 12080.69911005, 12059.46111822, 12293.29344173,
13922.37505932],
[10966.63085158, 9892.0676446 , 9871.59083448, 10668.29131865,
12542.14161737],
...,
[ 6947.18806981, 7153.18618246, 6632.30468226, 6694.4550466 ,
7828.18529177],
[ 6975.86758482, 7119.73368717, 6669.32047576, 6752.77560701,
7888.49692205],
[ 6957.30230509, 7115.27475563, 6673.78362864, 6737.87911551,
7887.28394481]],
...,
[[32056.58460472, 31221.7739104 , 31735.11211743, 30980.35578765,
29836.69076551],
[24218.40760568, 23789.5311144 , 24271.55225634, 24285.23357579,
23453.27809186],
[22412.08710106, 20456.79694285, 21299.12366854, 22661.35624444,
20973.26525058],
...,
[16454.71734397, 16742.88303755, 16434.45029685, 16977.18749631,
16636.07265785],
[16350.66203124, 16616.73244782, 16334.52910462, 16908.38780043,
16607.64076691],
[16323.87355556, 16572.61278334, 16294.41704234, 16894.15027625,
16609.56306209]],
[[18590.02165745, 17235.44142195, 18814.17965012, 17872.10795913,
19935.66690903],
[12545.50583983, 12158.38785942, 12098.32959322, 12249.16193259,
14111.48611477],
[ 9943.16093455, 9993.06403833, 9404.45346751, 10546.77686627,
12424.84421669],
...,
[ 6700.69555551, 7242.90642577, 6564.40116274, 6734.45672701,
7858.47749108],
[ 6751.38631504, 7207.3236194 , 6617.41886265, 6809.24830116,
7924.36850695],
[ 6742.00439249, 7201.00111724, 6624.13901521, 6796.06455959,
7931.04143302]],
[[17811.69635028, 16268.88554317, 17473.21022016, 18105.67233597,
19193.73570314],
[12147.5968244 , 11517.38271483, 11690.68796231, 12237.1067785 ,
13262.49541927],
[10910.16066836, 9478.67621929, 9623.70061864, 10866.47869636,
11117.81347274],
...,
[ 6361.4771658 , 6903.71230987, 6340.36813463, 6782.33087299,
7768.63764954],
[ 6380.81039283, 6873.57956559, 6310.55853668, 6814.11731302,
7842.22881964],
[ 6396.40918577, 6874.16028665, 6302.52214895, 6819.3908946 ,
7854.62620535]]]),
'split_params': {'n_splits': 5, 'random_state': 0, 'shuffle': True}}
[16]:
res_kfoldcv.iloc[info_kfoldcv['best_params_index']]
[16]:
max_depth 7.000000
max_features 1.000000
max_samples 0.700000
risk_val-1 18036.725421
risk_val-2 12449.665457
...
risk_test-46 6818.256763
risk_test-47 6822.130614
risk_test-48 6831.305227
risk_test-49 6844.258926
risk_test-50 6849.421744
Name: 7, Length: 103, dtype: float64
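The val_score array printed in info_kfoldcv appears to be indexed by (configuration, ensemble size, fold), with five values per row for the five folds. Assuming that layout, and assuming the selection averages over folds and takes the minimum, the best configuration and ensemble size could be recovered as in the sketch below. The array here is a random stand-in, and both the layout and the selection rule are inferences from the printed output rather than documented behavior.

```python
import numpy as np

rng = np.random.default_rng(0)
n_configs, M, n_splits = 8, 50, 5
# Stand-in for info_kfoldcv['val_score'], assumed layout (configs, ensemble sizes, folds)
base = 5000.0 + 10000.0 / np.arange(1, M + 1)   # risk decreasing in ensemble size
val_score = base[None, :, None] + rng.normal(scale=50.0, size=(n_configs, M, n_splits))

mean_over_folds = val_score.mean(axis=-1)       # (n_configs, M)
best_index, best_m = np.unravel_index(mean_over_folds.argmin(), mean_over_folds.shape)
best_n_estimators = int(best_m) + 1             # ensemble sizes start at 1
best_score = mean_over_folds[best_index, best_m]
print(int(best_index), best_n_estimators)
```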
On the test set, the parameters selected by ECV give the lowest mean squared error at M=50 (risk_test-50): 5722.97 for ECV, compared with 6762.71 for splitCV and 6849.42 for KFoldCV.