Deep Neural Networks for Survival Analysis Using 'torch'.
survdnn 
survdnn implements neural network-based models for right-censored survival analysis using the native torch backend in R. It supports multiple loss functions, including the Cox partial likelihood, an L2-penalized Cox loss, an Accelerated Failure Time (AFT) objective, and time-dependent extensions such as Cox-Time. The package provides a formula interface and supports model evaluation with time-dependent metrics (e.g., C-index, Brier score, integrated Brier score (IBS)), cross-validation, and hyperparameter tuning.
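Right-censored outcomes are represented with survival::Surv(), which is what the left-hand side of a survdnn formula refers to. The following is a minimal sketch that uses only the survival package (no survdnn calls) to inspect the veteran outcome used in the examples below:

```r
library(survival)

veteran <- survival::veteran

# Surv() pairs each follow-up time with an event indicator:
# status = 1 means the death was observed, status = 0 means right-censored.
# Censored times are printed with a trailing "+".
y <- Surv(veteran$time, veteran$status)
print(y[1:10])

# Observed events versus censored observations
n_events   <- sum(veteran$status == 1)
n_censored <- sum(veteran$status == 0)
event_rate <- mean(veteran$status == 1)
cat("events:", n_events, " censored:", n_censored,
    " event rate:", round(100 * event_rate, 1), "%\n")
```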
Features
- Formula interface for Surv() ~ . models
- Modular neural architectures: configurable layers, activations, and losses
- Built-in survival loss functions:
  - "cox": Cox partial likelihood
  - "cox_l2": L2-penalized Cox
  - "aft": Accelerated Failure Time
  - "coxtime": deep time-dependent Cox (Cox-Time, a non-proportional extension of DeepSurv-style Cox networks)
- Evaluation: C-index, Brier score, Integrated Brier Score (IBS)
- Model selection with cv_survdnn() and tune_survdnn()
- Prediction of survival curves via predict() and plot()
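The C-index listed above measures how well a risk score ranks subjects by survival time. The following is a minimal base-R sketch of Harrell's C-index. harrell_cindex() is a hypothetical helper, not a package function. It assumes that higher scores mean higher risk and skips pairs with tied times. survdnn's own implementation (for example, a time-dependent variant evaluated at specific times) may differ.

```r
# Harrell's concordance index for a risk score (higher score = higher risk).
# A pair (i, j) is comparable when subject i has an observed event and
# subject j is still under follow-up at a strictly later time; the pair is
# concordant when i also has the higher risk score. Tied scores count 0.5.
harrell_cindex <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  for (i in which(status == 1)) {
    later <- time > time[i]
    comparable <- comparable + sum(later)
    concordant <- concordant +
      sum(risk[i] > risk[later]) + 0.5 * sum(risk[i] == risk[later])
  }
  concordant / comparable
}

# Toy data: risk decreases with survival time, so every comparable pair agrees
time   <- c(1, 2, 3, 4, 5)
status <- c(1, 0, 1, 1, 0)
harrell_cindex(time, status, risk = c(5, 4, 3, 2, 1))  # 1
harrell_cindex(time, status, risk = c(1, 2, 3, 4, 5))  # 0
harrell_cindex(time, status, risk = rep(1, 5))         # 0.5
```

A value of 0.5 corresponds to random ranking and 1 to perfect ranking. This is why a constant score lands exactly at 0.5.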
Installation
# Install from GitHub
# install.packages("remotes")
remotes::install_github("ielbadisy/survdnn")
# Or clone and install locally
# git clone https://github.com/ielbadisy/survdnn.git
# setwd("survdnn")
# devtools::install()
Quick Example
library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)
veteran <- survival::veteran
mod <- survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
hidden = c(32, 16),
epochs = 100,
loss = "cox",
verbose = TRUE
)
## Epoch 50 - Loss: 3.987919
## Epoch 100 - Loss: 3.974391
summary(mod)
##
## ── Summary of survdnn model ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
##
## Formula:
## Surv(time, status) ~ age + karno + celltype
## <environment: 0x5b3739336aa0>
##
## Model architecture:
## Hidden layers: 32 : 16
## Activation: relu
## Dropout: 0.3
## Final loss: 3.974391
##
## Training summary:
## Epochs: 100
## Learning rate: 1e-04
## Loss function: cox
##
## Data summary:
## Observations: 137
## Predictors: age, karno, celltypesmallcell, celltypeadeno, celltypelarge
## Time range: [ 1, 999 ]
## Event rate: 93.4%
plot(mod, group_by = "celltype", times = 1:300)
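The Data summary in the output above lists the predictors after the formula has been expanded. The four-level factor celltype becomes three indicator columns, with squamous as the reference level. The following is a minimal sketch that reproduces those columns, the reported time range, and the event rate with base R's model.matrix(). The assumption that survdnn uses default treatment contrasts internally is inferred from the printed column names, not confirmed by the package documentation.

```r
library(survival)

veteran <- survival::veteran

# Expand the right-hand side of the formula into a numeric design matrix.
# celltype has levels squamous, smallcell, adeno, large; with the default
# treatment contrasts, squamous is the reference and gets no column.
X <- model.matrix(~ age + karno + celltype, data = veteran)
# Drop the intercept (the printed predictor list has none)
X <- X[, colnames(X) != "(Intercept)", drop = FALSE]
print(colnames(X))

# The other quantities reported under "Data summary"
print(nrow(X))                               # observations
print(range(veteran$time))                   # time range
print(round(100 * mean(veteran$status), 1))  # event rate in percent
```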
Loss Functions
# Cox partial likelihood
mod1 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "cox",
epochs = 100
)
## Epoch 50 - Loss: 4.216911
## Epoch 100 - Loss: 4.105076
# Accelerated Failure Time
mod2 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "aft",
epochs = 100
)
## Epoch 50 - Loss: 21.136486
## Epoch 100 - Loss: 20.663244
# Deep time-dependent Cox (Coxtime)
mod3 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "coxtime",
epochs = 100
)
## Epoch 50 - Loss: 4.856084
## Epoch 100 - Loss: 5.289982
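The "cox" loss above is the negative log partial likelihood of the network's output, treated as a log-risk score. The following is a minimal base-R sketch. It assumes Breslow handling of tied event times and averages over the number of events; how survdnn actually scales the loss or handles ties is an assumption. cox_nll() is a hypothetical helper, not part of the package. Passing it the linear predictor of a classical Cox model recovers that model's log partial likelihood, which makes a convenient sanity check:

```r
library(survival)

# Negative log Cox partial likelihood, averaged over observed events.
# lp: log-risk scores (e.g., network outputs). Ties use Breslow's method,
# so every subject with time >= t_i is in the risk set of event i.
cox_nll <- function(lp, time, status) {
  events <- which(status == 1)
  log_risk_set <- vapply(events, function(i) {
    at_risk <- lp[time >= time[i]]
    m <- max(at_risk)                  # log-sum-exp for numerical stability
    m + log(sum(exp(at_risk - m)))
  }, numeric(1))
  -sum(lp[events] - log_risk_set) / length(events)
}

veteran <- survival::veteran
fit <- coxph(Surv(time, status) ~ age + karno, data = veteran, ties = "breslow")
lp  <- predict(fit, type = "lp")

# Summed over events, the loss equals minus coxph's maximized log partial likelihood
total_nll <- cox_nll(lp, veteran$time, veteran$status) * sum(veteran$status)
all.equal(total_nll, -fit$loglik[2])
```

The loss depends only on differences between scores within each risk set. Adding a constant to lp (as predict() does when it centers covariates) therefore leaves it unchanged.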
Cross-Validation
cv_results <- cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(30, 90, 180),
metrics = c("cindex", "ibs"),
folds = 3,
hidden = c(16, 8),
loss = "cox",
epochs = 100
)
print(cv_results)
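The cross-validation call above reports the integrated Brier score ("ibs"). The following is a minimal sketch of the standard inverse-probability-of-censoring-weighted (IPCW) Brier score at a single horizon t, plus trapezoidal integration over a time grid. brier_ipcw() and ibs_trapezoid() are hypothetical helpers. Whether survdnn uses exactly this weighting and integration rule is an assumption.

```r
library(survival)

# IPCW Brier score at horizon t.
# surv_t[i] is the predicted probability that subject i survives beyond t.
# G is the Kaplan-Meier estimate of the censoring distribution, P(C > u).
brier_ipcw <- function(time, status, surv_t, t) {
  cens <- survfit(Surv(time, 1 - status) ~ 1)
  G_left  <- stepfun(cens$time, c(1, cens$surv), right = TRUE)  # G(u-)
  G_right <- stepfun(cens$time, c(1, cens$surv))                 # G(u)
  died_before <- time <= t & status == 1   # event observed by t
  alive_after <- time > t                  # still under follow-up after t
  # Subjects censored before t contribute zero
  score <- numeric(length(time))
  score[died_before] <- surv_t[died_before]^2 / G_left(time[died_before])
  score[alive_after] <- (1 - surv_t[alive_after])^2 / G_right(t)
  mean(score)
}

# Integrated Brier score: trapezoidal area under BS(t), divided by the time span
ibs_trapezoid <- function(times, bs) {
  area <- sum(diff(times) * (head(bs, -1) + tail(bs, -1)) / 2)
  area / (max(times) - min(times))
}

# Toy check with one subject censored at time 2
time   <- c(1, 2, 3, 4)
status <- c(1, 0, 1, 1)
brier_ipcw(time, status, surv_t = c(0.2, 0.4, 0.6, 0.8), t = 2.5)  # 0.085
```

In the toy check, the censoring estimate drops to 2/3 at time 2. The two subjects still under follow-up after t = 2.5 are therefore up-weighted by 1.5, which compensates for the censored subject who no longer contributes.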
Hyperparameter Tuning
grid <- list(
hidden = list(c(16), c(32, 16)),
lr = c(1e-3),
activation = c("relu"),
epochs = c(100, 300),
loss = c("cox", "aft", "coxtime")
)
tune_res <- tune_survdnn(
formula = Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(90, 300),
metrics = "cindex",
param_grid = grid,
folds = 3,
refit = FALSE,
return = "summary"
)
print(tune_res)
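Each entry of param_grid lists the candidate values for one argument. This sketch assumes that tune_survdnn() evaluates the full Cartesian product of these values, which is an assumption about its internals. Under that assumption, the grid above defines 2 × 1 × 1 × 2 × 3 = 12 configurations. Each is cross-validated over folds = 3, giving 36 model fits. A base-R sketch of that count:

```r
grid <- list(
  hidden = list(c(16), c(32, 16)),
  lr = c(1e-3),
  activation = c("relu"),
  epochs = c(100, 300),
  loss = c("cox", "aft", "coxtime")
)

# Cartesian product of the candidate values. hidden is indexed by position
# because each of its entries is itself a vector of layer widths.
configs <- expand.grid(
  hidden_id  = seq_along(grid$hidden),
  lr         = grid$lr,
  activation = grid$activation,
  epochs     = grid$epochs,
  loss       = grid$loss,
  stringsAsFactors = FALSE
)

n_folds <- 3
cat(nrow(configs), "configurations,", nrow(configs) * n_folds, "model fits\n")
```

Counting fits this way before launching a search helps estimate run time. Adding a second learning rate, for example, would double the number of fits.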
Plot Survival Curves
plot(mod1, group_by = "celltype", times = 1:300)
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
Documentation
help(package = "survdnn")
?survdnn
?tune_survdnn
?cv_survdnn
?plot.survdnn
Testing
# Run all tests
devtools::test()
Availability
The survdnn R package is available at: https://github.com/ielbadisy/survdnn
The package is currently under submission to CRAN.
Contributions
Contributions, issues, and feature requests are welcome. Open an issue or submit a pull request!
License
MIT © Imad El Badisy.