Sunday, June 28, 2020

Sebastian Pölsterl: scikit-survival 0.13 Released

Today, I released version 0.13.0 of scikit-survival. Most notably, this release adds sksurv.metrics.brier_score and sksurv.metrics.integrated_brier_score, an updated PEP 517/518 compatible build system, and support for scikit-learn 0.23.

For a full list of changes in scikit-survival 0.13.0, please see the release notes.

Pre-built conda packages are available for Linux, macOS, and Windows via

 conda install -c sebp scikit-survival

Alternatively, scikit-survival can be installed from source following these instructions.

The time-dependent Brier score

The time-dependent Brier score is an extension of the mean squared error to right-censored data:

$$ \mathrm{BS}^c(t) = \frac{1}{n} \sum_{i=1}^n I(y_i \leq t \land \delta_i = 1) \frac{(0 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(y_i)} + I(y_i > t) \frac{(1 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(t)} , $$

where $\hat{\pi}(t | \mathbf{x})$ is a model’s predicted probability of remaining event-free up to time point $t$ for feature vector $\mathbf{x}$, and $1/\hat{G}(t)$ is an inverse probability of censoring weight.
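To make the formula concrete, here is a minimal NumPy sketch that evaluates it at a single time point. It is not the implementation in sksurv.metrics.brier_score: the function name `brier_score_at` and its arguments are made up for illustration, and the censoring estimates $\hat{G}(y_i)$ and $\hat{G}(t)$ are assumed to be supplied by the caller.

```python
import numpy as np


def brier_score_at(t, time, event, surv_prob, G_at_time, G_at_t):
    """Time-dependent Brier score at a single time point t (illustrative sketch).

    time, event : observed times y_i and event indicators delta_i
    surv_prob   : predicted probability of being event-free at t, one per sample
    G_at_time   : censoring survival estimate G(y_i), one per sample (assumed given)
    G_at_t      : censoring survival estimate G(t), a scalar (assumed given)
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    surv_prob = np.asarray(surv_prob, dtype=float)
    G_at_time = np.asarray(G_at_time, dtype=float)

    had_event = (time <= t) & event   # first indicator in the formula
    still_at_risk = time > t          # second indicator in the formula

    # Samples censored before t contribute zero to both terms.
    # The inner np.where avoids dividing by G(y_i) where the term is masked out.
    safe_G = np.where(had_event, G_at_time, 1.0)
    term_event = np.where(had_event, (0.0 - surv_prob) ** 2 / safe_G, 0.0)
    term_risk = np.where(still_at_risk, (1.0 - surv_prob) ** 2 / G_at_t, 0.0)
    return float(np.mean(term_event + term_risk))


# Toy data without censoring, so every weight is 1 and the score
# reduces to the ordinary mean squared error.
bs = brier_score_at(
    t=6,
    time=[2, 5, 8, 10],
    event=[True, True, True, True],
    surv_prob=[0.1, 0.4, 0.7, 0.9],
    G_at_time=[1.0, 1.0, 1.0, 1.0],
    G_at_t=1.0,
)
```

With these toy values, the two samples that had an event before $t = 6$ contribute $0.1^2$ and $0.4^2$, and the two samples still at risk contribute $0.3^2$ and $0.1^2$.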

The Brier score is often used to assess calibration. If a model predicts a 10% risk of experiencing an event at time $t$, the observed frequency in the data should match this percentage for a well-calibrated model. The Brier score is also a measure of discrimination: whether a model is able to predict risk scores that allow us to correctly determine the order of events. The concordance index is probably the most common measure of discrimination. However, the concordance index disregards the actual values of predicted risk scores (it is a ranking metric) and is unable to tell us anything about calibration.
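The ranking-only nature of the concordance index can be illustrated with a small sketch. The `concordance_index` function below is a simplified, hand-written version of Harrell's estimator (pairs with tied times are skipped), not the one shipped with scikit-survival, and the data is made up.

```python
import numpy as np


def concordance_index(time, event, risk):
    """Simplified Harrell's c-index; ties in risk count as 0.5, ties in time are skipped."""
    concordant, comparable = 0.0, 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a pair is only comparable if the earlier time is an event
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable


time = np.array([5, 8, 12, 20, 25])
event = np.array([True, True, False, True, False])
risk = np.array([0.9, 0.4, 0.6, 0.3, 0.1])

c_original = concordance_index(time, event, risk)
# Shrinking every score by a factor of 100 keeps the ordering intact.
c_rescaled = concordance_index(time, event, risk / 100.0)
```

Both calls give the same value because only the ordering of the scores enters the computation. If the scores were read as event probabilities, the rescaled ones would be badly off, which is exactly what the concordance index cannot detect.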

Let’s consider an example based on data from the German Breast Cancer Study Group 2.

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import encode_categorical
from sklearn.model_selection import train_test_split
X, y = load_gbsg2()
X = encode_categorical(X)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y["cens"], random_state=1)

We want to train a model on the training data and assess its discrimination and calibration on the test data. Here, we consider a Random Survival Forest and Cox’s proportional hazards model with elastic-net penalty.

from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxnetSurvivalAnalysis
rsf = RandomSurvivalForest(max_depth=2, random_state=1)
rsf.fit(X_train, y_train)
cph = CoxnetSurvivalAnalysis(l1_ratio=0.99, fit_baseline_model=True)
cph.fit(X_train, y_train)

First, let’s start with discrimination as measured by the concordance index.

rsf_c = rsf.score(X_test, y_test)
cph_c = cph.score(X_test, y_test)

The result indicates that both models perform equally well, achieving a concordance index of 0.688, which is considerably better than a random model with a concordance index of 0.5. Unfortunately, it doesn’t help us decide which model we should choose. So let’s consider the time-dependent Brier score as an alternative, which assesses both discrimination and calibration.

We first need to determine the time points $t$ at which we want to compute the Brier score. We are going to use a data-driven approach here by selecting all time points between the 10th and 90th percentile of observed time points.

import numpy as np
lower, upper = np.percentile(y["time"], [10, 90])
times = np.arange(lower, upper + 1)

This returns 1690 time points, for which we need to estimate the probability of survival, which is given by the survival function. Thus, we iterate over the predicted survival functions on the test data and evaluate each at the time points from above.

rsf_surv_prob = np.row_stack([
    fn(times)
    for fn in rsf.predict_survival_function(X_test, return_array=False)
])
cph_surv_prob = np.row_stack([
    fn(times)
    for fn in cph.predict_survival_function(X_test)
])

Note that calling predict_survival_function for RandomSurvivalForest with return_array=False requires scikit-survival 0.13.
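To show what evaluating and stacking the survival functions produces, here is a self-contained sketch with two made-up step functions. The helper `make_step_function` is hypothetical and only imitates a right-continuous step function; it is not scikit-survival's StepFunction class.

```python
import numpy as np


def make_step_function(x, y):
    """Return a right-continuous step function with value y[i] on [x[i], x[i+1])."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    def fn(query):
        query = np.asarray(query, dtype=float)
        if np.any(query < x[0]):
            raise ValueError("query point lies before the first breakpoint")
        # Index of the last breakpoint that is <= each query point.
        idx = np.searchsorted(x, query, side="right") - 1
        return y[idx]

    return fn


# Two made-up survival curves that share breakpoints but differ in values.
breakpoints = [0, 10, 20, 30]
surv_fns = [
    make_step_function(breakpoints, [1.0, 0.9, 0.7, 0.5]),
    make_step_function(breakpoints, [1.0, 0.8, 0.5, 0.2]),
]
times = np.array([5, 10, 25])

# One row per sample, one column per time point.
surv_prob = np.vstack([fn(times) for fn in surv_fns])
```

The result has shape `(n_samples, n_times)`, which is the layout the Brier score functions expect for their estimate argument.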

In addition, we want to have a baseline to tell us how much better our models are than random. A random model would simply predict 0.5 every time.

random_surv_prob = 0.5 * np.ones((y_test.shape[0], times.shape[0]))
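A quick sanity check on this baseline: a constant prediction of 0.5 has a squared error of 0.25 whatever the outcome. The sketch below uses made-up outcomes and assumes no censoring, so the Brier score is a plain mean squared error.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up indicators at some time t: 1 = still event-free, 0 = had an event.
event_free = rng.integers(0, 2, size=1000)
constant_pred = np.full(event_free.shape, 0.5)

# Without censoring, the Brier score is the mean squared error.
brier_constant = np.mean((event_free - constant_pred) ** 2)
```

With censoring weights the score is no longer exactly 0.25, but it would typically stay close to that value.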

Another useful reference is the Kaplan-Meier estimator, which does not consider any features: it estimates a survival function only from y_test. We replicate this estimate for all samples in the test data.

from sksurv.functions import StepFunction
from sksurv.nonparametric import kaplan_meier_estimator
km_func = StepFunction(*kaplan_meier_estimator(y_test["cens"], y_test["time"]))
km_surv_prob = np.tile(km_func(times), (y_test.shape[0], 1))
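For readers unfamiliar with the Kaplan-Meier estimator, the following is a minimal sketch of the product-limit computation on toy data. It is a simplified stand-in written for illustration, not the code behind sksurv.nonparametric.kaplan_meier_estimator.

```python
import numpy as np


def kaplan_meier(event, time):
    """Return the unique observed times and the survival estimate at each of them."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)

    unique_times = np.unique(time)
    surv = np.empty(unique_times.shape[0])
    current = 1.0
    for k, t in enumerate(unique_times):
        at_risk = np.sum(time >= t)            # samples still under observation at t
        deaths = np.sum((time == t) & event)   # events occurring exactly at t
        current *= 1.0 - deaths / at_risk
        surv[k] = current
    return unique_times, surv


event = [True, False, True, True, False]
time = [3, 5, 5, 8, 10]
km_times, km_surv = kaplan_meier(event, time)
```

Here the estimate drops to 0.8 at time 3, to 0.6 at time 5, and to 0.3 at time 8. The censored observations at times 5 and 10 cause no drop of their own and only reduce the number of samples at risk.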

Instead of comparing calibration across all 1690 time points, we’ll be using the integrated Brier score (IBS) over all time points, which will give us a single number to compare the models by.

from sksurv.metrics import integrated_brier_score
random_brier = integrated_brier_score(y, y_test, random_surv_prob, times)
km_brier = integrated_brier_score(y, y_test, km_surv_prob, times)
rsf_brier = integrated_brier_score(y, y_test, rsf_surv_prob, times)
cph_brier = integrated_brier_score(y, y_test, cph_surv_prob, times)
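Conceptually, the integrated score is the area under the time-dependent Brier score curve, normalized by the length of the time range. A minimal sketch of that idea follows, assuming trapezoidal integration; the helper name `integrate_brier` and the toy numbers are made up.

```python
import numpy as np


def integrate_brier(times, brier_scores):
    """Trapezoidal area under the Brier score curve, divided by the time range."""
    times = np.asarray(times, dtype=float)
    brier_scores = np.asarray(brier_scores, dtype=float)
    widths = np.diff(times)
    heights = (brier_scores[1:] + brier_scores[:-1]) / 2.0
    return float(np.sum(widths * heights) / (times[-1] - times[0]))


# Toy curve: the score rises from 0.10 to 0.20 and then stays flat.
times = np.array([100.0, 200.0, 400.0])
brier_scores = np.array([0.10, 0.20, 0.20])
ibs = integrate_brier(times, brier_scores)
```

Because of the normalization, the result stays on the same scale as the Brier score itself: a curve that is constant at 0.2 integrates to 0.2.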

The results are summarized in the table below:

          RSF     Coxnet  Random  Kaplan-Meier
c-index   0.688   0.688   0.500
IBS       0.194   0.188   0.247   0.217

Despite Random Survival Forest and Cox’s proportional hazards model performing equally well in terms of discrimination, there seems to be a notable difference in terms of calibration, with Cox’s proportional hazards model outperforming Random Survival Forest.

As a final note, I want to clarify that the Brier score is only applicable for models that are able to estimate a survival function. Hence, it currently cannot be used with Survival Support Vector Machines.



from Planet Python