Saturday, October 30, 2021

Sebastian Pölsterl: scikit-survival 0.16 released

I am proud to announce the release of version 0.16.0 of scikit-survival. The biggest improvement in this release is that you can now change the evaluation metric used by an estimator’s score method. This is particularly useful for hyper-parameter optimization with scikit-learn’s GridSearchCV. You can now wrap an estimator with as_concordance_index_ipcw_scorer, as_cumulative_dynamic_auc_scorer, or as_integrated_brier_score_scorer to adjust the score method to your needs. The example below illustrates how to use these in practice.

For a full list of changes in scikit-survival 0.16.0, please see the release notes.

Installation

Pre-built conda packages are available for Linux, macOS, and Windows via

 conda install -c sebp scikit-survival

Alternatively, scikit-survival can be installed from source following these instructions.

Hyper-Parameter Optimization with Alternative Metrics

The code below is also available as a notebook that can be executed directly.

In this example, we are going to use the German Breast Cancer Study Group 2 dataset. We want to fit a Random Survival Forest and optimize its max_depth hyper-parameter using scikit-learn’s GridSearchCV.

Let’s begin by loading the data.

import numpy as np
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import encode_categorical
gbsg_X, gbsg_y = load_gbsg2()
gbsg_X = encode_categorical(gbsg_X)
lower, upper = np.percentile(gbsg_y["time"], [10, 90])
gbsg_times = np.arange(lower, upper + 1)

Next, we create an instance of Random Survival Forest.

from sksurv.ensemble import RandomSurvivalForest
rsf_gbsg = RandomSurvivalForest(random_state=1)

We define that we want to evaluate the performance of each hyper-parameter configuration by 3-fold cross-validation.

from sklearn.model_selection import KFold
cv = KFold(n_splits=3, shuffle=True, random_state=1)
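
To make the splitting concrete, here is a minimal pure-Python sketch of shuffled 3-fold cross-validation. It is not scikit-learn's implementation: the helper k_fold_indices is hypothetical, and its shuffling will not reproduce the exact folds that KFold generates with random_state=1.

```python
import random


def k_fold_indices(n_samples, n_splits=3, seed=1):
    """Yield (train, test) index lists for shuffled k-fold cross-validation.

    Hypothetical helper for illustration only.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Spread the remainder over the first folds so fold sizes differ by at most one.
    fold_sizes = [
        n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
        for i in range(n_splits)
    ]
    start = 0
    for size in fold_sizes:
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, test
        start += size


folds = list(k_fold_indices(10, n_splits=3))
for train, test in folds:
    print(len(train), len(test))
```

Every sample lands in exactly one test fold, so each hyper-parameter configuration is scored three times on data the model did not see during fitting.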

Next, we define the set of hyper-parameters to evaluate. Here, we search for the best value of max_depth from 1 to 9 (the upper bound of 10 is excluded). Note that we have to prefix max_depth with estimator__, because we are going to wrap the actual RandomSurvivalForest instance with one of the wrappers introduced above.

cv_param_grid = {
    "estimator__max_depth": np.arange(1, 10, dtype=int),
}
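
The double underscore follows scikit-learn's convention for addressing parameters of nested objects. The sketch below mimics that routing with two toy classes; ToyEstimator and ToyWrapper are hypothetical stand-ins, not the actual scikit-survival or scikit-learn classes, and the real set_params logic is more elaborate.

```python
class ToyEstimator:
    """Stand-in for a model with a single hyper-parameter (hypothetical)."""

    def __init__(self, max_depth=None):
        self.max_depth = max_depth

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self


class ToyWrapper:
    """Stand-in for a wrapper that stores the wrapped model as `estimator`."""

    def __init__(self, estimator):
        self.estimator = estimator

    def set_params(self, **params):
        for name, value in params.items():
            # "estimator__max_depth" -> attribute "estimator", sub-parameter "max_depth"
            attr, sep, sub_name = name.partition("__")
            if sep:
                getattr(self, attr).set_params(**{sub_name: value})
            else:
                setattr(self, attr, value)
        return self


wrapped = ToyWrapper(ToyEstimator())
wrapped.set_params(estimator__max_depth=5)
print(wrapped.estimator.max_depth)  # 5
```

Without the prefix, the parameter would be looked up on the wrapper itself rather than on the wrapped forest.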

Now, we can put all the pieces together and start searching for the best hyper-parameters that maximize concordance_index_ipcw.

from sklearn.model_selection import GridSearchCV
from sksurv.metrics import as_concordance_index_ipcw_scorer
gcv_cindex = GridSearchCV(
    as_concordance_index_ipcw_scorer(rsf_gbsg, tau=gbsg_times[-1]),
    param_grid=cv_param_grid,
    cv=cv,
).fit(gbsg_X, gbsg_y)
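
For intuition about what is being maximized, the following sketch computes the simpler Harrell's concordance index in pure Python. Note that this is not concordance_index_ipcw: it applies no inverse probability of censoring weights and no truncation at tau, it ignores tied times, and the data are made-up toy values.

```python
def harrell_c_index(event, time, risk):
    """Harrell's concordance index (no IPCW weights, no truncation at tau).

    Simplified sketch: pairs with tied times are not counted.
    """
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            # A pair is only comparable if the earlier time is an observed event.
            continue
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # higher risk failed earlier: concordant
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risk scores count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: the second sample is censored, the others are observed events.
event = [True, False, True, True]
time = [5, 8, 12, 20]
risk = [0.9, 0.4, 0.3, 0.6]
c_index = harrell_c_index(event, time, risk)
print(c_index)  # 0.75
```

Three of the four comparable pairs are ordered correctly here; a value of 0.5 corresponds to random ordering and 1.0 to perfect ordering.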

The same process applies when optimizing hyper-parameters to maximize cumulative_dynamic_auc.

from sksurv.metrics import as_cumulative_dynamic_auc_scorer
gcv_iauc = GridSearchCV(
    as_cumulative_dynamic_auc_scorer(rsf_gbsg, times=gbsg_times),
    param_grid=cv_param_grid,
    cv=cv,
).fit(gbsg_X, gbsg_y)

While as_concordance_index_ipcw_scorer and as_cumulative_dynamic_auc_scorer can be used with any estimator, as_integrated_brier_score_scorer is only available for estimators that provide the predict_survival_function method, which includes RandomSurvivalForest. When it is used, the hyper-parameters that maximize the negative integrated time-dependent Brier score are selected, because a lower Brier score indicates better performance.

from sksurv.metrics import as_integrated_brier_score_scorer
gcv_ibs = GridSearchCV(
    as_integrated_brier_score_scorer(rsf_gbsg, times=gbsg_times),
    param_grid=cv_param_grid,
    cv=cv,
).fit(gbsg_X, gbsg_y)
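
The sign flip can be illustrated with a small pure-Python sketch. It assumes the integrated score is the trapezoidal integral of the time-dependent Brier score divided by the length of the time interval; the helper integrated_score is hypothetical, and the Brier score values are made-up numbers, not results from the GBSG2 data.

```python
def integrated_score(times, scores):
    """Trapezoidal integral of scores over times, normalized by the time span."""
    area = 0.0
    for k in range(1, len(times)):
        area += 0.5 * (scores[k - 1] + scores[k]) * (times[k] - times[k - 1])
    return area / (times[-1] - times[0])


# Toy time-dependent Brier scores for two hypothetical max_depth values (made up).
times = [1.0, 2.0, 3.0]
brier_by_depth = {
    2: [0.10, 0.20, 0.30],
    5: [0.10, 0.15, 0.20],
}

# Negate so that larger is better, which is how a grid search ranks candidates.
neg_ibs = {
    depth: -integrated_score(times, scores)
    for depth, scores in brier_by_depth.items()
}
best_depth = max(neg_ibs, key=neg_ibs.get)
print(best_depth)  # 5
```

The configuration with the lower integrated Brier score has the larger negated score and is therefore the one selected.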

Finally, we can visualize the results of the grid search and compare the best performing hyper-parameter configurations (marked with a red dot).

import matplotlib.pyplot as plt

def plot_grid_search_results(gcv, ax, name):
    ax.errorbar(
        x=gcv.cv_results_["param_estimator__max_depth"].filled(),
        y=gcv.cv_results_["mean_test_score"],
        yerr=gcv.cv_results_["std_test_score"],
    )
    ax.plot(
        gcv.best_params_["estimator__max_depth"],
        gcv.best_score_,
        'ro',
    )
    ax.set_ylabel(name)
    ax.yaxis.grid(True)

_, axs = plt.subplots(3, 1, figsize=(6, 6), sharex=True)
axs[-1].set_xlabel("max_depth")
plot_grid_search_results(gcv_cindex, axs[0], "c-index")
plot_grid_search_results(gcv_iauc, axs[1], "iAUC")
plot_grid_search_results(gcv_ibs, axs[2], "$-$IBS")

Results of hyper-parameter optimization.

When optimizing for the concordance index, a high maximum depth works best, whereas the other metrics are best when choosing a maximum depth of 5 and 6, respectively.



from Planet Python