survdnn
Deep Neural Networks for Survival Analysis using R torch
About
survdnn implements neural network-based models for right-censored survival analysis using the native torch backend in R. It supports multiple loss functions, including Cox partial likelihood, L2-penalized Cox, and Accelerated Failure Time (AFT) objectives, as well as time-dependent extensions such as Cox-Time. The package provides a formula interface and supports model evaluation with time-dependent metrics (C-index, Brier score, IBS), cross-validation, and hyperparameter tuning.
Review status
A methodological paper describing the design, implementation, and evaluation of survdnn is currently under review at The R Journal.
Main features
- Formula interface for Surv() ~ . models
- Modular neural architectures: configurable layers, activations, optimizers, and losses
- Built-in survival loss functions:
  - "cox": Cox partial likelihood
  - "cox_l2": L2-penalized Cox
  - "aft": Accelerated Failure Time
  - "coxtime": deep time-dependent Cox
- Evaluation: C-index, Brier score, IBS
- Model selection with cv_survdnn() and tune_survdnn()
- Prediction of survival curves via predict() and plot()
Installation
# Install from CRAN
install.packages("survdnn")
# Install from GitHub
install.packages("remotes")
remotes::install_github("ielbadisy/survdnn")
# Or clone and install locally
# (first, in a shell: git clone https://github.com/ielbadisy/survdnn.git)
setwd("survdnn")
devtools::install()
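Note that survdnn builds on the torch R package, which downloads its libtorch backend the first time it is used. If torch has been installed but never initialized, a one-time setup step along these lines may be needed:
# one-time backend download for torch (no-op if already installed)
if (!torch::torch_is_installed()) torch::install_torch()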
Quick example
library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)
veteran <- survival::veteran
mod <- survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
hidden = c(32, 16),
epochs = 300,
loss = "cox",
verbose = TRUE
)
## Epoch 50 - Loss: 3.967377
## Epoch 100 - Loss: 3.863189
## Epoch 150 - Loss: 3.879065
## Epoch 200 - Loss: 3.814478
## Epoch 250 - Loss: 3.756944
## Epoch 300 - Loss: 3.823366
summary(mod)
##
## Formula:
## Surv(time, status) ~ age + karno + celltype
## <environment: 0x6171aa19de98>
##
## Model architecture:
## Hidden layers: 32 : 16
## Activation: relu
## Dropout: 0.3
## Final loss: 3.823366
##
## Training summary:
## Epochs: 300
## Learning rate: 1e-04
## Loss function: cox
## Optimizer: adam
##
## Data summary:
## Observations: 137
## Predictors: age, karno, celltypesmallcell, celltypeadeno, celltypelarge
## Time range: [ 1, 999 ]
## Event rate: 93.4%
plot(mod, group_by = "celltype", times = 1:300)
Loss functions
- Cox partial likelihood
mod1 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "cox",
epochs = 300
)
## Epoch 50 - Loss: 3.988259
## Epoch 100 - Loss: 3.930287
## Epoch 150 - Loss: 3.913787
## Epoch 200 - Loss: 3.896528
## Epoch 250 - Loss: 3.819792
## Epoch 300 - Loss: 3.893889
- Accelerated Failure Time
mod2 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "aft",
epochs = 300
)
## Epoch 50 - Loss: 16.911470
## Epoch 100 - Loss: 16.589067
## Epoch 150 - Loss: 16.226612
## Epoch 200 - Loss: 15.959708
## Epoch 250 - Loss: 15.182121
## Epoch 300 - Loss: 15.049762
- Cox-Time
mod3 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "coxtime",
epochs = 300
)
## Epoch 50 - Loss: 4.888907
## Epoch 100 - Loss: 4.846722
## Epoch 150 - Loss: 4.838490
## Epoch 200 - Loss: 4.816662
## Epoch 250 - Loss: 4.780379
## Epoch 300 - Loss: 4.756117
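The three objectives yield different risk scores for the same data. A quick way to compare them is to predict on the training set with the type = "risk" interface shown later in this README (here we assume a single time point yields one risk score per observation):
# risk predictions at day 180 from each fitted model
risk_cox     <- predict(mod1, veteran, type = "risk", times = 180)
risk_aft     <- predict(mod2, veteran, type = "risk", times = 180)
risk_coxtime <- predict(mod3, veteran, type = "risk", times = 180)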
Cross-validation
cv_results <- cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(600),
metrics = c("cindex", "ibs"),
folds = 3,
hidden = c(16, 8),
loss = "cox",
epochs = 300
)
print(cv_results)
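cv_results holds one entry per fold and metric. A compact summary can be obtained by averaging over folds; the sketch below assumes the result is a data frame with metric and value columns (check str(cv_results) for the actual layout):
# mean performance across folds (column names are assumptions)
aggregate(value ~ metric, data = cv_results, FUN = mean)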
Hyperparameter tuning
grid <- list(
hidden = list(c(16), c(32, 16)),
lr = c(1e-3),
activation = c("relu"),
epochs = c(100, 300),
loss = c("cox", "aft", "coxtime")
)
tune_res <- tune_survdnn(
formula = Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(90, 300),
metrics = "cindex",
param_grid = grid,
folds = 3,
refit = FALSE,
return = "summary"
)
print(tune_res)
Tuning and refitting the best model
tune_survdnn() can also be used to automatically refit the best-performing model on the full dataset. This behavior is controlled by the refit and return arguments. For example:
best_model <- tune_survdnn(
formula = Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(90, 300),
metrics = "cindex",
param_grid = grid,
folds = 3,
refit = TRUE,
return = "best_model"
)
In this mode, cross-validation is used to select the optimal hyperparameter configuration, after which the selected model is refitted on the full dataset. The function then returns a fitted object of class "survdnn".
The resulting model can be used directly for prediction, visualization, and evaluation:
summary(best_model)
plot(best_model, times = 1:300)
predict(best_model, veteran, type = "risk", times = 180)
This makes tune_survdnn() suitable for end-to-end workflows, combining model selection and final model fitting.
Plot survival curves
plot(mod1, group_by = "celltype", times = 1:300)
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
Documentation
help(package = "survdnn")
?survdnn
?tune_survdnn
?cv_survdnn
?plot.survdnn
Testing
# run all tests
devtools::test()
Note on reproducibility
By default, {torch} initializes model weights and shuffles minibatches using random draws, so results may differ across runs. Unlike set.seed(), which only controls R’s random number generator, {torch} relies on its own RNG implemented in C++ (and CUDA when using GPUs).
To ensure reproducibility, random seeds must therefore be set at the Torch level as well.
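Concretely, controlling both generators by hand looks like this (a minimal sketch; survdnn automates it, as described below):
set.seed(123)                  # controls R's RNG (e.g. data splits)
torch::torch_manual_seed(123)  # controls Torch's RNG (weight init, dropout, shuffling)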
survdnn provides built-in control of randomness to guarantee reproducible results across runs. The main fitting function, survdnn(), exposes a dedicated .seed argument:
mod <- survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
epochs = 300,
.seed = 123
)
When .seed is provided, survdnn() internally synchronizes both R and Torch random number generators via survdnn_set_seed(), ensuring reproducible:
weight initialization
dropout behavior
minibatch ordering
loss trajectories
If .seed = NULL (the default), randomness is left uncontrolled and results may vary between runs.
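Refitting with the same seed should reproduce the run exactly. One way to check is to compare predictions from two identically seeded fits:
# refit with the same seed and verify that predictions match
mod_repeat <- survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  epochs = 300,
  .seed = 123
)
all.equal(
  predict(mod, veteran, type = "risk", times = 180),
  predict(mod_repeat, veteran, type = "risk", times = 180)
)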
For full reproducibility in cross-validation or hyperparameter tuning, the same .seed mechanism is propagated internally by cv_survdnn() and tune_survdnn(), ensuring consistent data splits, model initialization, and optimization paths across repetitions.
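For example, passing the same .seed to cv_survdnn() should make repeated cross-validation runs identical (assuming, per the above, that .seed is forwarded to each fold's fit):
cv_fixed <- cv_survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  times = 600,
  metrics = "cindex",
  folds = 3,
  loss = "cox",
  epochs = 300,
  .seed = 123
)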
CPU and core usage
survdnn relies on the {torch} backend for numerical computation. The number of CPU cores (threads) used during training, prediction, and evaluation is controlled globally by Torch.
By default, Torch automatically configures its CPU thread pools based on the available system resources, unless explicitly overridden by the user using:
# cap Torch's CPU thread pool at 4 threads for all subsequent operations
torch::torch_set_num_threads(4)
This setting affects:
model training
prediction
evaluation metrics
cross-validation and hyperparameter tuning
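The current setting can be inspected with the matching accessor:
# report how many CPU threads Torch will use
torch::torch_get_num_threads()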
GPU acceleration can be enabled by setting .device = "cuda" when calling survdnn(); the same argument is accepted by cv_survdnn() and tune_survdnn().
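For instance, on a machine with a CUDA-enabled torch installation:
# train on the GPU (assumes torch::cuda_is_available() returns TRUE)
mod_gpu <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "cox",
  epochs = 300,
  .device = "cuda"
)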
Availability
The survdnn R package is available on CRAN and GitHub.
Contributions
Contributions, issues, and feature requests are welcome!
Open an issue or submit a pull request.
License
MIT License © 2025 Imad EL BADISY.