Description

Deep Neural Networks for Survival Analysis Using 'torch'.

Provides deep learning models for right-censored survival data using the 'torch' backend. Supports multiple loss functions, including Cox partial likelihood, L2-penalized Cox, time-dependent Cox, and accelerated failure time (AFT) loss. Offers a formula-based interface, built-in support for cross-validation, hyperparameter tuning, survival curve plotting, and evaluation metrics such as the C-index, Brier score, and integrated Brier score. For methodological details, see Kvamme et al. (2019) <https://www.jmlr.org/papers/v20/18-424.html>.

survdnn

Deep Neural Networks for Survival Analysis Using torch

survdnn implements neural network-based models for right-censored survival analysis using the native torch backend in R. It supports multiple loss functions, including Cox partial likelihood, L2-penalized Cox, and Accelerated Failure Time (AFT) objectives, as well as time-dependent extensions such as Cox-Time. The package provides a formula interface and supports cross-validation, hyperparameter tuning, and model evaluation with time-dependent metrics (e.g., C-index, Brier score, IBS).


Features

  • Formula interface for Surv() ~ . models
  • Modular neural architectures: configurable layers, activations, and losses
  • Built-in survival loss functions:
    • "cox": Cox partial likelihood
    • "cox_l2": L2-penalized Cox partial likelihood
    • "aft": Accelerated Failure Time
    • "coxtime": deep time-dependent Cox (Cox-Time, Kvamme et al. 2019)
  • Evaluation: C-index, Brier score, Integrated Brier Score (IBS)
  • Model selection with cv_survdnn() and tune_survdnn()
  • Prediction and plotting of survival curves via predict() and plot()
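
The Brier score and IBS named in the feature list can be illustrated in a few lines of base R. This is a minimal sketch under a simplifying assumption: no subject is censored before the evaluation times, so no censoring weights are needed. The helper names `brier_at` and `integrated_brier` and the toy data are hypothetical and are not part of survdnn, whose own implementation may differ.

```r
# Brier score at time t: mean squared gap between the predicted survival
# probability S(t | x_i) and the observed indicator 1{T_i > t}.
# Simplification: every subject's status at t is assumed known (no censoring).
brier_at <- function(surv_prob, time, t) {
  alive <- as.numeric(time > t)
  mean((alive - surv_prob)^2)
}

# Integrated Brier score: trapezoidal integral of the Brier curve over the
# evaluation times (assumed sorted), divided by the width of the time window.
integrated_brier <- function(brier, times) {
  widths  <- diff(times)
  heights <- (head(brier, -1) + tail(brier, -1)) / 2
  sum(widths * heights) / (max(times) - min(times))
}

# Toy data (hypothetical): four subjects, predictions at three time points.
time  <- c(20, 60, 120, 200)
times <- c(30, 90, 180)
pred  <- rbind(            # rows = subjects, columns = evaluation times
  c(0.4, 0.2, 0.1),
  c(0.8, 0.3, 0.1),
  c(0.9, 0.7, 0.4),
  c(0.9, 0.8, 0.6)
)

brier <- sapply(seq_along(times), function(j) brier_at(pred[, j], time, times[j]))
ibs   <- integrated_brier(brier, times)
print(brier)
print(ibs)
```

Lower values are better for both quantities. With censored data, the squared errors are usually reweighted by an estimate of the censoring distribution, which this sketch leaves out.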

Installation

# Install from GitHub
# install.packages("remotes")
remotes::install_github("ielbadisy/survdnn")

# Or clone and install locally
# git clone https://github.com/ielbadisy/survdnn.git
# setwd("survdnn")
# devtools::install()

Quick Example

library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)

veteran <- survival::veteran

mod <- survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  hidden = c(32, 16),
  epochs = 100,
  loss = "cox",
  verbose = TRUE
)
## Epoch 50 - Loss: 3.987919
## Epoch 100 - Loss: 3.974391
summary(mod)
## 
## ── Summary of survdnn model ──────────────────────────────
## 
## Formula:
##   Surv(time, status) ~ age + karno + celltype
## <environment: 0x5b3739336aa0>
## 
## Model architecture:
##   Hidden layers:  32 : 16 
##   Activation:  relu 
##   Dropout:  0.3 
##   Final loss:  3.974391 
## 
## Training summary:
##   Epochs:  100 
##   Learning rate:  1e-04 
##   Loss function:  cox 
## 
## Data summary:
##   Observations:  137 
##   Predictors:  age, karno, celltypesmallcell, celltypeadeno, celltypelarge 
##   Time range: [ 1, 999 ]
##   Event rate:  93.4%
plot(mod, group_by = "celltype", times = 1:300)
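
The Predictors line in the summary lists three celltype columns for a four-level factor, which is what standard treatment (dummy) coding of a formula produces. The sketch below reproduces that expansion with base R's model.matrix() on a toy data frame. The data values are hypothetical, and whether survdnn calls model.matrix() internally is an assumption.

```r
# Toy data frame (hypothetical values) with the same column types as `veteran`.
toy <- data.frame(
  age      = c(69, 64, 38, 63),
  karno    = c(60, 70, 60, 60),
  celltype = factor(
    c("squamous", "smallcell", "adeno", "large"),
    levels = c("squamous", "smallcell", "adeno", "large")
  )
)

# Expand the right-hand side of the formula into a numeric design matrix.
# The first factor level ("squamous") is the reference and gets no column.
x <- model.matrix(~ age + karno + celltype, data = toy)

# Drop the intercept column; the remaining columns would be the network inputs.
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
print(colnames(x))
```

The number of input units is therefore the number of expanded columns (five here), not the number of terms in the formula (three).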

Loss Functions

# Cox partial likelihood
mod1 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "cox",
  epochs = 100
)
## Epoch 50 - Loss: 4.216911
## Epoch 100 - Loss: 4.105076

# Accelerated Failure Time
mod2 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "aft",
  epochs = 100
)
## Epoch 50 - Loss: 21.136486
## Epoch 100 - Loss: 20.663244

# Deep time-dependent Cox (Coxtime)
mod3 <- survdnn(
  Surv(time, status) ~ age + karno,
  data = veteran,
  loss = "coxtime",
  epochs = 100
)
## Epoch 50 - Loss: 4.856084
## Epoch 100 - Loss: 5.289982
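
To make the "cox" objective concrete, here is a minimal base R sketch of a negative log partial likelihood computed from one risk score per subject. The function name `cox_neg_log_pl` is hypothetical, ties are handled Breslow-style, and dividing by the number of events is one possible normalization. survdnn's torch implementation may scale or order things differently, so the values will not match the losses logged above.

```r
# Negative log partial likelihood for a vector of risk scores `lp`
# (one network output per subject). Every subject with time >= t_i
# belongs to the risk set of subject i (Breslow-style handling of ties).
cox_neg_log_pl <- function(lp, time, status) {
  log_risk <- vapply(
    seq_along(time),
    function(i) log(sum(exp(lp[time >= time[i]]))),
    numeric(1)
  )
  # Only observed events (status == 1) contribute terms.
  -sum(status * (lp - log_risk)) / sum(status)
}

# Toy data (hypothetical): three subjects, all events, no ties.
time   <- c(1, 2, 3)
status <- c(1, 1, 1)

loss_good <- cox_neg_log_pl(lp = c(2, 1, 0), time, status)  # high risk fails first
loss_bad  <- cox_neg_log_pl(lp = c(0, 1, 2), time, status)  # ordering reversed
print(c(good = loss_good, bad = loss_bad))
```

Scores that rank early failures as high risk give a smaller loss than the reversed ranking, which is the behavior gradient descent exploits during training.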

Cross-Validation

cv_results <- cv_survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  times = c(30, 90, 180),
  metrics = c("cindex", "ibs"),
  folds = 3,
  hidden = c(16, 8),
  loss = "cox",
  epochs = 100
)
print(cv_results)
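
The "cindex" metric requested above measures how well predicted risks rank subjects by failure time. As an illustration, the sketch below computes Harrell's C-index in base R. It is a simpler relative of a time-dependent C-index, the function name `harrell_c` and the toy data are hypothetical, and it is not survdnn's implementation.

```r
# Harrell's C-index: among comparable pairs (subject i has an observed event
# and fails before subject j), the share where i also has the higher risk.
# Tied risk scores count as half-concordant.
harrell_c <- function(risk, time, status) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next          # the earlier subject must have an event
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {        # subject i failed before subject j
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5
        }
      }
    }
  }
  concordant / comparable
}

# Toy data (hypothetical): the second subject is censored at time 10.
time   <- c(5, 10, 15, 20)
status <- c(1, 0, 1, 1)
risk   <- c(0.9, 0.2, 0.6, 0.1)

print(harrell_c(risk, time, status))
```

A value of 0.5 corresponds to random ranking and 1 to perfect ranking. In cross-validation the metric is computed on each held-out fold and then summarized across folds.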

Hyperparameter Tuning

grid <- list(
  hidden     = list(c(16), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100, 300),
  loss       = c("cox", "aft", "coxtime")
)

tune_res <- tune_survdnn(
  formula = Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  times = c(90, 300),
  metrics = "cindex",
  param_grid = grid,
  folds = 3,
  refit = FALSE,
  return = "summary"
)
print(tune_res)
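
It can help to see how many candidate models a grid like the one above describes. The base R sketch below enumerates every combination while keeping vector-valued entries such as hidden = c(32, 16) intact. It is an illustration only: the variable names `idx` and `configs` are hypothetical, and how tune_survdnn() expands the grid internally is an assumption.

```r
grid <- list(
  hidden     = list(c(16), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100, 300),
  loss       = c("cox", "aft", "coxtime")
)

# One row per combination, holding an index into each grid entry.
# Indexing (rather than expanding the values directly) keeps list entries
# such as the layer sizes in `hidden` as whole vectors.
idx <- expand.grid(lapply(grid, seq_along))

# Turn each row of indices into a named list of concrete settings.
configs <- lapply(seq_len(nrow(idx)), function(i) {
  Map(function(values, j) values[[j]], grid, as.list(idx[i, ]))
})

print(length(configs))
str(configs[[2]])
```

This grid yields 2 x 1 x 1 x 2 x 3 = 12 configurations. If each one is fitted once per fold, folds = 3 implies 36 model fits, which is worth estimating before launching a larger search.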

Plot Survival Curves

plot(mod1, group_by = "celltype", times = 1:300)

plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
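
A group-level curve of the kind requested with plot_mean_only = TRUE can be thought of as the average of the individual predicted curves within each group. The base R sketch below shows that averaging on hypothetical predictions. The matrix layout (rows = subjects, columns = times) and the averaging rule are assumptions about what the plot summarizes, not a description of plot.survdnn() internals.

```r
# Hypothetical predicted survival probabilities: rows = subjects, columns = times.
times <- c(30, 90, 180)
surv  <- rbind(
  c(0.9, 0.7, 0.5),
  c(0.7, 0.5, 0.3),
  c(0.8, 0.4, 0.2),
  c(0.6, 0.2, 0.0)
)
group <- factor(c("adeno", "adeno", "large", "large"))

# Average the individual curves within each group: one mean curve per level.
mean_curves <- t(sapply(
  split(seq_len(nrow(surv)), group),
  function(rows) colMeans(surv[rows, , drop = FALSE])
))
colnames(mean_curves) <- paste0("t=", times)
print(mean_curves)
```

Each row of `mean_curves` is one group's average survival curve evaluated at the requested times.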


Documentation

help(package = "survdnn")
?survdnn
?tune_survdnn
?cv_survdnn
?plot.survdnn

Testing

# Run all tests
devtools::test()

Availability

The survdnn R package is available at: https://github.com/ielbadisy/survdnn

The package is currently under submission to CRAN.


Contributions

Contributions, issues, and feature requests are welcome. Open an issue or submit a pull request!


License

MIT © Imad El Badisy.

Metadata

Version

0.6.0

License

Unknown

Platforms (75)

    Darwin
    FreeBSD
    Genode
    GHCJS
    Linux
    MMIXware
    NetBSD
    none
    OpenBSD
    Redox
    Solaris
    WASI
    Windows
  • aarch64-darwin
  • aarch64-freebsd
  • aarch64-genode
  • aarch64-linux
  • aarch64-netbsd
  • aarch64-none
  • aarch64-windows
  • aarch64_be-none
  • arm-none
  • armv5tel-linux
  • armv6l-linux
  • armv6l-netbsd
  • armv6l-none
  • armv7a-linux
  • armv7a-netbsd
  • armv7l-linux
  • armv7l-netbsd
  • avr-none
  • i686-cygwin
  • i686-freebsd
  • i686-genode
  • i686-linux
  • i686-netbsd
  • i686-none
  • i686-openbsd
  • i686-windows
  • javascript-ghcjs
  • loongarch64-linux
  • m68k-linux
  • m68k-netbsd
  • m68k-none
  • microblaze-linux
  • microblaze-none
  • microblazeel-linux
  • microblazeel-none
  • mips-linux
  • mips-none
  • mips64-linux
  • mips64-none
  • mips64el-linux
  • mipsel-linux
  • mipsel-netbsd
  • mmix-mmixware
  • msp430-none
  • or1k-none
  • powerpc-netbsd
  • powerpc-none
  • powerpc64-linux
  • powerpc64le-linux
  • powerpcle-none
  • riscv32-linux
  • riscv32-netbsd
  • riscv32-none
  • riscv64-linux
  • riscv64-netbsd
  • riscv64-none
  • rx-none
  • s390-linux
  • s390-none
  • s390x-linux
  • s390x-none
  • vc4-none
  • wasm32-wasi
  • wasm64-wasi
  • x86_64-cygwin
  • x86_64-darwin
  • x86_64-freebsd
  • x86_64-genode
  • x86_64-linux
  • x86_64-netbsd
  • x86_64-none
  • x86_64-openbsd
  • x86_64-redox
  • x86_64-solaris
  • x86_64-windows