Description

Fitting Interpretable Neural Additive Models Using Orthogonalization.

An algorithm for fitting interpretable additive neural networks with identifiable and visualizable feature effects using post hoc orthogonalization. Fit custom neural networks intuitively using established R formula notation, including interaction effects of arbitrary order, while preserving identifiability to enable a functional decomposition of the prediction function. For more details, see Koehler et al. (2025) <doi:10.1038/s44387-025-00033-7>.

Orthogonal neural additive models for interpretable machine learning: functional decomposition of black-box models into explainable predictor effects.

Install package:

If this is the first time you are using keras or tensorflow in R, you need to run keras3::install_keras(). On certain systems, especially Windows, install_keras() may not succeed. In this case, use install_conda_env() to create a conda environment and install the relevant Python modules.
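A typical one-time setup might look like the following sketch. It assumes the package is available from CRAN under the name onam; the install_conda_env() fallback is only needed if install_keras() fails:

```r
# Install the package (assumption: released on CRAN as "onam")
install.packages("onam")
library(onam)

# One-time backend setup for keras/tensorflow
keras3::install_keras()

# Fallback, e.g. on Windows, if install_keras() does not succeed:
# install_conda_env()
```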

This readme walks you through the workflow of specifying, fitting, and analysing an orthogonal neural additive model.

Simulate example data

# Create training data
set.seed(1) # for reproducibility
n <- 1000
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
x3 <- runif(n, -2, 2)
noise <- rnorm(n, 0, 1)
y <- sin(x1) + dt(x2, 1) * 4 + 2 * x3 +
  x1 * x2 + noise
train <- cbind(x1, x2, x3, y)

Specify model architecture(s)

Different feature (interaction) effects can use different model architectures, varying in type and complexity. These architectures are specified in the list_of_deep_models argument.

simple_model <- function(inputs) {
  outputs <- inputs %>%
    layer_dense(units = 128, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 64, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 32, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 16, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 8, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 1, activation = "linear", use_bias = TRUE)
  keras_model(inputs, outputs)
}
complex_model <- function(inputs) {
  outputs <- inputs %>%
    layer_dense(units = 512, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 256, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 128, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 64, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 32, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 16, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 8, activation = "relu", use_bias = TRUE) %>%
    layer_dense(units = 1, activation = "linear", use_bias = TRUE)
  keras_model(inputs, outputs)
}

list_of_deep_models <- list(simple = simple_model,
                            complex = complex_model)
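As an illustrative aside (not part of the onam workflow itself), each architecture function simply maps a keras input tensor to a keras model, so you can sanity-check it directly, assuming keras3 is loaded:

```r
library(keras3)

# Feed a 2-column input tensor (as used for the pairwise interaction
# effect below) through the architecture function and inspect the result.
inputs <- keras_input(shape = 2)
m <- complex_model(inputs)
summary(m) # layer stack ending in a single linear output unit
```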

Specify effects of interest

The effects to be fitted are supplied as a formula object. The function name applied to each feature (interaction) must correspond to a name in the list_of_deep_models argument.

model_formula <- y ~ simple(x1) + simple(x2) + simple(x3) +
  complex(x1, x2) + complex(x1, x2, x3)

Fit onam model

An onam model can then be fitted according to the model formula. An ensembling strategy is used, here with two ensemble members. Further fitting parameters can be specified, such as verbosity, callbacks, or the number of training epochs.

mod <- onam(formula = model_formula, 
            list_of_deep_models = list_of_deep_models,
            data = train, n_ensemble = 2, epochs = 100)

Model evaluation

The fitted model can be inspected for goodness of fit and degree of interpretability via the summary() function.

summary(mod)

Fitted feature effects can be visualized as ggplot figures:

plot_main_effect(mod, "x1")
plot_inter_effect(mod, "x1", "x2", interpolate = TRUE)

Metadata

Version

1.0.1

License

Unknown

Platforms (78)

    Darwin
    FreeBSD
    Genode
    GHCJS
    Linux
    MMIXware
    NetBSD
    none
    OpenBSD
    Redox
    Solaris
    uefi
    WASI
    Windows
  • aarch64-darwin
  • aarch64-freebsd
  • aarch64-genode
  • aarch64-linux
  • aarch64-netbsd
  • aarch64-none
  • aarch64-uefi
  • aarch64-windows
  • aarch64_be-none
  • arm-none
  • armv5tel-linux
  • armv6l-linux
  • armv6l-netbsd
  • armv6l-none
  • armv7a-linux
  • armv7a-netbsd
  • armv7l-linux
  • armv7l-netbsd
  • avr-none
  • i686-cygwin
  • i686-freebsd
  • i686-genode
  • i686-linux
  • i686-netbsd
  • i686-none
  • i686-openbsd
  • i686-windows
  • javascript-ghcjs
  • loongarch64-linux
  • m68k-linux
  • m68k-netbsd
  • m68k-none
  • microblaze-linux
  • microblaze-none
  • microblazeel-linux
  • microblazeel-none
  • mips-linux
  • mips-none
  • mips64-linux
  • mips64-none
  • mips64el-linux
  • mipsel-linux
  • mipsel-netbsd
  • mmix-mmixware
  • msp430-none
  • or1k-none
  • powerpc-linux
  • powerpc-netbsd
  • powerpc-none
  • powerpc64-linux
  • powerpc64le-linux
  • powerpcle-none
  • riscv32-linux
  • riscv32-netbsd
  • riscv32-none
  • riscv64-linux
  • riscv64-netbsd
  • riscv64-none
  • rx-none
  • s390-linux
  • s390-none
  • s390x-linux
  • s390x-none
  • vc4-none
  • wasm32-wasi
  • wasm64-wasi
  • x86_64-cygwin
  • x86_64-darwin
  • x86_64-freebsd
  • x86_64-genode
  • x86_64-linux
  • x86_64-netbsd
  • x86_64-none
  • x86_64-openbsd
  • x86_64-redox
  • x86_64-solaris
  • x86_64-uefi
  • x86_64-windows