Description
Interpretable Survival Machine Learning Framework.
Description
A modular toolkit for interpretable survival machine learning with a unified interface for fitting, prediction, evaluation, and interpretation. It includes semiparametric, parametric, tree-based, ensemble, boosting, kernel, and deep-learning survival learners, together with benchmarking, scoring, calibration, and model-agnostic interpretation utilities. Representative methodological anchors include Cox (1972) <doi:10.1111/j.2517-6161.1972.tb00899.x>, Royston and Parmar (2002) <doi:10.1002/sim.1203>, Ishwaran et al. (2008) <doi:10.1214/08-AOAS169>, Jaeger et al. (2019) <doi:10.1214/19-AOAS1261>, Harrell et al. (1982) <doi:10.1001/jama.1982.03320430047030>, Graf et al. (1999) <doi:10.1002/(SICI)1097-0258(19990915/30)18:17/18%3C2529::AID-SIM274%3E3.0.CO;2-5>, Friedman (2001) <doi:10.1214/aos/1013203451>, Apley and Zhu (2020) <doi:10.1111/rssb.12377>, and Lundberg and Lee (2017) <https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions>, and other related methods for survival modeling, prediction, and interpretation.
README.md
Survalis: Unified Survival Machine Learning and Interpretability in R
survalis provides a unified framework for survival machine learning survival analysis in R. It supports a wide range of learners, evaluation metrics, cross-validation and interpretability methods.
Installation
# Install from GitHub
remotes::install_github("ielbadisy/survalis")
# Or from source
install.packages("survalis_0.7.0.tar.gz", repos = NULL, type = "source")
Core philosophy
- Consistent function patterns:
fit_*(),predict_*(),tune_*() - Learners return standardized
mlsurv_modelobjects - Predictions return
data.frameof survival probabilities:t=100,t=200, … - Evaluation is fully modular: plug any
fit_*/predict_*withcv_survlearner()orscore_survmodel() - Designed for interpretability post prediction.
Exploring the package
List all available survival learners
library(survalis)
# See all available learners
list_survlearners()
#> # A tibble: 19 × 8
#> learner fit predict tune has_fit has_predict has_tune available
#> <chr> <chr> <chr> <chr> <lgl> <lgl> <lgl> <lgl>
#> 1 coxph fit_cox… predic… <NA> TRUE TRUE FALSE TRUE
#> 2 aalen fit_aal… predic… <NA> TRUE TRUE FALSE TRUE
#> 3 glmnet fit_glm… predic… tune… TRUE TRUE TRUE TRUE
#> 4 selectcox fit_sel… predic… tune… TRUE TRUE TRUE TRUE
#> 5 aftgee fit_aft… predic… <NA> TRUE TRUE FALSE TRUE
#> 6 flexsurvreg fit_fle… predic… tune… TRUE TRUE TRUE TRUE
#> 7 stpm2 fit_stp… predic… <NA> TRUE TRUE FALSE TRUE
#> 8 bnnsurv fit_bnn… predic… tune… TRUE TRUE TRUE TRUE
#> 9 rpart fit_rpa… predic… tune… TRUE TRUE TRUE TRUE
#> 10 bart fit_bart predic… tune… TRUE TRUE TRUE TRUE
#> 11 xgboost fit_xgb… predic… tune… TRUE TRUE TRUE TRUE
#> 12 ranger fit_ran… predic… tune… TRUE TRUE TRUE TRUE
#> 13 rsf fit_rsf predic… tune… TRUE TRUE TRUE TRUE
#> 14 cforest fit_cfo… predic… tune… TRUE TRUE TRUE TRUE
#> 15 blackboost fit_bla… predic… tune… TRUE TRUE TRUE TRUE
#> 16 survsvm fit_sur… predic… tune… TRUE TRUE TRUE TRUE
#> 17 survdnn fit_sur… predic… tune… TRUE TRUE TRUE TRUE
#> 18 orsf fit_orsf predic… tune… TRUE TRUE TRUE TRUE
#> 19 survmetalearner fit_sur… predic… <NA> TRUE TRUE FALSE TRUE
# See only tunable learners (those with a tune_* function)
list_survlearners(has_tune = TRUE)
#> # A tibble: 14 × 8
#> learner fit predict tune has_fit has_predict has_tune available
#> <chr> <chr> <chr> <chr> <lgl> <lgl> <lgl> <lgl>
#> 1 glmnet fit_glmnet predic… tune… TRUE TRUE TRUE TRUE
#> 2 selectcox fit_selectc… predic… tune… TRUE TRUE TRUE TRUE
#> 3 flexsurvreg fit_flexsur… predic… tune… TRUE TRUE TRUE TRUE
#> 4 bnnsurv fit_bnnsurv predic… tune… TRUE TRUE TRUE TRUE
#> 5 rpart fit_rpart predic… tune… TRUE TRUE TRUE TRUE
#> 6 bart fit_bart predic… tune… TRUE TRUE TRUE TRUE
#> 7 xgboost fit_xgboost predic… tune… TRUE TRUE TRUE TRUE
#> 8 ranger fit_ranger predic… tune… TRUE TRUE TRUE TRUE
#> 9 rsf fit_rsf predic… tune… TRUE TRUE TRUE TRUE
#> 10 cforest fit_cforest predic… tune… TRUE TRUE TRUE TRUE
#> 11 blackboost fit_blackbo… predic… tune… TRUE TRUE TRUE TRUE
#> 12 survsvm fit_survsvm predic… tune… TRUE TRUE TRUE TRUE
#> 13 survdnn fit_survdnn predic… tune… TRUE TRUE TRUE TRUE
#> 14 orsf fit_orsf predic… tune… TRUE TRUE TRUE TRUE
# Shortcut for tunable learners
list_tunable_survlearners()
#> # A tibble: 14 × 8
#> learner fit predict tune has_fit has_predict has_tune available
#> <chr> <chr> <chr> <chr> <lgl> <lgl> <lgl> <lgl>
#> 1 glmnet fit_glmnet predic… tune… TRUE TRUE TRUE TRUE
#> 2 selectcox fit_selectc… predic… tune… TRUE TRUE TRUE TRUE
#> 3 flexsurvreg fit_flexsur… predic… tune… TRUE TRUE TRUE TRUE
#> 4 bnnsurv fit_bnnsurv predic… tune… TRUE TRUE TRUE TRUE
#> 5 rpart fit_rpart predic… tune… TRUE TRUE TRUE TRUE
#> 6 bart fit_bart predic… tune… TRUE TRUE TRUE TRUE
#> 7 xgboost fit_xgboost predic… tune… TRUE TRUE TRUE TRUE
#> 8 ranger fit_ranger predic… tune… TRUE TRUE TRUE TRUE
#> 9 rsf fit_rsf predic… tune… TRUE TRUE TRUE TRUE
#> 10 cforest fit_cforest predic… tune… TRUE TRUE TRUE TRUE
#> 11 blackboost fit_blackbo… predic… tune… TRUE TRUE TRUE TRUE
#> 12 survsvm fit_survsvm predic… tune… TRUE TRUE TRUE TRUE
#> 13 survdnn fit_survdnn predic… tune… TRUE TRUE TRUE TRUE
#> 14 orsf fit_orsf predic… tune… TRUE TRUE TRUE TRUE
List interpretability tools
# List available interpretability methods
list_interpretability_methods()
#> # A tibble: 8 × 4
#> compute plot has_compute has_plot
#> <chr> <chr> <lgl> <lgl>
#> 1 compute_shap plot_shap TRUE TRUE
#> 2 compute_pdp plot_pdp TRUE TRUE
#> 3 compute_ale plot_ale TRUE TRUE
#> 4 compute_surrogate plot_surrogate TRUE TRUE
#> 5 compute_tree_surrogate plot_tree_surrogate TRUE TRUE
#> 6 compute_varimp plot_varimp TRUE TRUE
#> 7 compute_interactions plot_interactions TRUE TRUE
#> 8 compute_counterfactual <NA> TRUE FALSE
# Show which compute_* methods have a plot_* counterpart
subset(list_interpretability_methods(), !is.na(plot))
#> # A tibble: 7 × 4
#> compute plot has_compute has_plot
#> <chr> <chr> <lgl> <lgl>
#> 1 compute_shap plot_shap TRUE TRUE
#> 2 compute_pdp plot_pdp TRUE TRUE
#> 3 compute_ale plot_ale TRUE TRUE
#> 4 compute_surrogate plot_surrogate TRUE TRUE
#> 5 compute_tree_surrogate plot_tree_surrogate TRUE TRUE
#> 6 compute_varimp plot_varimp TRUE TRUE
#> 7 compute_interactions plot_interactions TRUE TRUE
List evaluation metrics
# List available metrics used in cross-validation and scoring
list_metrics()
#> # A tibble: 3 × 4
#> metric direction summary range
#> <chr> <chr> <chr> <chr>
#> 1 cindex maximize Harrell-style concordance index for survival predictio… [0, …
#> 2 brier minimize Brier Score at specified evaluation time(s) (IPCW-weig… [0, …
#> 3 ibs minimize Integrated Brier Score over an evaluation time grid (I… [0, …
Basic Workflow
1. Fit a model
mod_cox <- fit_coxph(Surv(time, status) ~ age + karno + celltype, data = veteran)
summary(mod_cox)
#>
#> ── coxph summary ───────────────────────────────────────────────────────────────
#> Formula:
#> Surv(time, status) ~ age + karno + celltype
#> Engine: survival
#> Learner: coxph
#> Data summary:
#> - Observations: 137
#> - Predictors: "age, karno, celltypesmallcell, celltypeadeno, celltypelarge"
#> - Time range: [1, 999]
#> - Event rate: "93.4%"
2. Predict survival probabilities
pred <- predict_coxph(mod_cox, newdata = veteran[1:5, ], times = c(100, 200))
pred
#> t=100 t=200
#> 1 0.6142681 0.3541697
#> 2 0.6944383 0.4599242
#> 3 0.5556797 0.2860796
#> 4 0.6033305 0.3408724
#> 5 0.6959633 0.4620783
3. Evaluate model performance
Direct evalution (single split):
score <- score_survmodel(mod_cox, times = c(100, 200), metrics = c("cindex", "ibs"))
score
#> # A tibble: 2 × 2
#> metric value
#> <chr> <dbl>
#> 1 cindex 0.734
#> 2 ibs 0.160
cv_res <- cv_survlearner(
Surv(time, status) ~ age + karno + celltype,
veteran,
fit_coxph,
predict_coxph,
times = 80,
metrics = c("cindex", "ibs"),
folds = 5,
seed = 123,
verbose = FALSE
)
cv_res
#> # A tibble: 10 × 5
#> splits id fold metric value
#> <list> <chr> <int> <chr> <dbl>
#> 1 <split [109/28]> Fold1 1 cindex 0.699
#> 2 <split [109/28]> Fold1 1 ibs 0.227
#> 3 <split [109/28]> Fold2 2 cindex 0.812
#> 4 <split [109/28]> Fold2 2 ibs 0.141
#> 5 <split [110/27]> Fold3 3 cindex 0.695
#> 6 <split [110/27]> Fold3 3 ibs 0.217
#> 7 <split [110/27]> Fold4 4 cindex 0.698
#> 8 <split [110/27]> Fold4 4 ibs 0.188
#> 9 <split [110/27]> Fold5 5 cindex 0.688
#> 10 <split [110/27]> Fold5 5 ibs 0.138
cv_summary(cv_res)
#> # A tibble: 2 × 7
#> metric mean sd n se lower upper
#> <chr> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
#> 1 cindex 0.718 0.0528 5 0.0236 0.672 0.765
#> 2 ibs 0.182 0.0417 5 0.0186 0.146 0.219
4. Visualize interpretation
shap_meanabs <- compute_shap(
model = mod_cox,
newdata = veteran[100,],
baseline_data = veteran,
times = 80,
sample.size = 50,
aggregate = TRUE,
method = "meanabs"
)
shap_meanabs
#> feature phi
#> trt trt 0.000000000
#> celltype celltype 0.005850095
#> karno karno 0.045326970
#> diagtime diagtime 0.000000000
#> age age 0.002540069
#> prior prior 0.000000000
plot_shap(shap_meanabs)

More interpretability methods
survalis also provides PDP, ALE, surrogate explanations, tree surrogates, permutation importance, interaction analysis, and counterfactuals.
Partial dependence and ICE
pdp_age <- compute_pdp(
model = mod_cox,
data = veteran,
feature = "age",
times = c(100, 200, 300),
method = "pdp+ice"
)
plot_pdp(pdp_age, feature = "age", which = "per_time")
#> Warning: `aes_string()` was deprecated in ggplot2 3.0.0.
#> ℹ Please use tidy evaluation idioms with `aes()`.
#> ℹ See also `vignette("ggplot2-in-packages")` for more information.
#> ℹ The deprecated feature was likely used in the survalis package.
#> Please report the issue to the authors.
#> This warning is displayed once per session.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
#> generated.
#> Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
#> ℹ Please use `linewidth` instead.
#> ℹ The deprecated feature was likely used in the survalis package.
#> Please report the issue to the authors.
#> This warning is displayed once per session.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
#> generated.

plot_pdp(pdp_age, feature = "age", which = "integrated", smooth = TRUE)
#> `geom_smooth()` using formula = 'y ~ x'

Accumulated local effects
ale_karno <- compute_ale(
model = mod_cox,
newdata = veteran,
feature = "karno",
times = c(100, 200, 300)
)
plot_ale(ale_karno, feature = "karno", which = "per_time")

plot_ale(ale_karno, feature = "karno", which = "integrated", smooth = TRUE)
#> `geom_smooth()` using formula = 'y ~ x'

Local surrogate explanation
local_surrogate <- compute_surrogate(
model = mod_cox,
newdata = veteran[1, , drop = FALSE],
baseline_data = veteran,
times = c(100, 200, 300),
target_time = 200,
k = 5
)
local_surrogate
#> feature feature_value effect target_time
#> 1 karno 60 0.491034890 200
#> 2 celltype squamous 0.189632633 200
#> 3 age 69 0.120729843 200
#> 4 diagtime 7 0.001800378 200
#> 5 prior 0 0.000000000 200
plot_surrogate(local_surrogate, top_n = 10)

Tree surrogate
tree_surrogate <- compute_tree_surrogate(
model = mod_cox,
data = veteran,
times = c(100, 200, 300)
)
plot_tree_surrogate(tree_surrogate, type = "importance", top_n = 5)

# plot_tree_surrogate(tree_surrogate, type = "tree")
Permutation variable importance
varimp_res <- compute_varimp(
model = mod_cox,
times = c(100, 200, 300),
metric = "ibs",
n_repetitions = 5,
seed = 123
)
varimp_res
#> # A tibble: 6 × 5
#> feature importance importance_05 importance_95 scaled_importance
#> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 karno 0.0624 0.192 0.223 100
#> 2 celltype 0.0446 0.185 0.201 71.5
#> 3 age -0.00180 0.144 0.146 2.89
#> 4 trt 0 0.147 0.147 0
#> 5 diagtime 0 0.147 0.147 0
#> 6 prior 0 0.147 0.147 0
plot_varimp(varimp_res)

Feature interactions
interaction_1way <- compute_interactions(
model = mod_cox,
data = veteran,
times = c(100, 200, 300),
target_time = 200,
type = "1way"
)
interaction_heatmap <- compute_interactions(
model = mod_cox,
data = veteran,
times = c(100, 200, 300),
target_time = 200,
type = "heatmap"
)
interaction_time <- compute_interactions(
model = mod_cox,
data = veteran,
times = c(100, 200, 300),
type = "time"
)
plot_interactions(interaction_1way, type = "1way")

plot_interactions(interaction_heatmap, type = "heatmap")

plot_interactions(interaction_time, type = "time")

Counterfactual explanations
counterfactuals <- compute_counterfactual(
model = mod_cox,
newdata = veteran[1, , drop = FALSE],
times = c(100, 200, 300),
target_time = 200,
features_to_change = c("age", "karno", "diagtime"),
cost_penalty = 0.01
)
counterfactuals
#> feature original_value suggested_value survival_gain change_cost
#> 1 karno 60 81.0202 0.2347 21.0202
#> 2 diagtime 7 7.0808 0.0000 0.0808
#> 3 age 69 69.1313 0.0003 0.1313
#> penalized_gain
#> 1 0.0245
#> 2 -0.0008
#> 3 -0.0010
5. Calibration
compute_calibration(
model = mod_cox, data = veteran,
time = "time", status = "status",
eval_time = 80, n_bins = 10, n_boot = 30
) |> plot_calibration()

Citing
citation("survalis")