Description

Get the Insights of Your Neural Network.

Interpretation methods for analyzing the behavior and individual predictions of modern neural networks in a three-step procedure: Converting the model, running the interpretation method, and visualizing the results. Implemented methods are, e.g., 'Connection Weights' described by Olden et al. (2004) <doi:10.1016/j.ecolmodel.2004.03.013>, layer-wise relevance propagation ('LRP') described by Bach et al. (2015) <doi:10.1371/journal.pone.0130140>, deep learning important features ('DeepLIFT') described by Shrikumar et al. (2017) <arXiv:1704.02685> and gradient-based methods like 'SmoothGrad' described by Smilkov et al. (2017) <arXiv:1706.03825>, 'Gradient x Input' described by Baehrens et al. (2009) <arXiv:0912.1128> or 'Vanilla Gradient'.

innsight - Get the insights of your neural network



Introduction

innsight is an R package that interprets the behavior and explains individual predictions of modern neural networks. Many methods for explaining individual predictions already exist, but hardly any of them are implemented or available in R. Most of these so-called feature attribution methods are only implemented in Python and are thus difficult for the R community to access and use. In this sense, the package innsight provides a common interface for various interpretability methods for neural networks and can therefore be considered an R analogue to iNNvestigate or Captum for Python.

This package implements several model-specific interpretability (feature attribution) methods for neural networks in R, e.g.,

  • Layer-wise Relevance Propagation (LRP)
    • Including propagation rules: $\varepsilon$-rule and $\alpha$-$\beta$-rule
  • Deep Learning Important Features (DeepLift)
    • Including propagation rules for non-linearities: Rescale rule and RevealCancel rule
    • DeepSHAP
  • Gradient-based methods:
    • Vanilla Gradient and Gradient x Input
    • SmoothGrad
  • Connection Weights
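
A minimal sketch of how these methods might be invoked (the run_*() helper names are assumed from the innsight 0.3.0 interface and the arguments are simplified; see the Usage section below for the full workflow):

# Assumed method functions from innsight 0.3.0 (see the package reference)
run_lrp(converter, data, rule_name = "epsilon")  # LRP with the epsilon rule
run_deeplift(converter, data)                    # DeepLift (Rescale or RevealCancel rule)
run_deepshap(converter, data)                    # DeepSHAP
run_grad(converter, data)                        # Vanilla Gradient
run_smoothgrad(converter, data)                  # SmoothGrad
run_cw(converter)                                # Connection Weights (global, no data needed)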

[Figure: example results for these methods on ImageNet with the pretrained VGG19 network; see Example 3: ImageNet with keras for details]

The package innsight aims to be as flexible as possible and independent of the specific deep learning package in which the passed network was trained. Neural networks built with the libraries torch, keras and neuralnet can be passed directly and are internally converted into a torch model enriched with the information needed for the interpretation methods. It is also possible to pass an arbitrary network in the form of a named list (see the vignette for details).
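
For example, a minimal sketch of converting a small torch model (the architecture is made up for illustration; as shown in the Usage section below, torch models additionally require the input dimension):

library(torch)
library(innsight)

# A small sequential torch model (hypothetical architecture)
model <- nn_sequential(
  nn_linear(4, 8),
  nn_relu(),
  nn_linear(8, 3),
  nn_softmax(dim = 2)
)

# For torch models, the input dimension (without the batch dimension) must be given
converter <- convert(model, input_dim = c(4))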

Installation

The stable version of the package can be installed directly from CRAN and the development version from GitHub with the following commands (the GitHub installation requires a successful installation of devtools):

# Stable version
install.packages("innsight")

# Development version
devtools::install_github("bips-hb/innsight")

Internally, any passed model is converted to a torch model; thus, the correct functionality of this package relies on a complete and correct installation of torch. For this reason, the following command must be run manually once to install the required libraries LibTorch and LibLantern:

torch::install_torch()

📝 Note
Currently, this can lead to problems on Windows if the Visual Studio runtime is not pre-installed. See the corresponding issue on GitHub, or consult the official installation vignette of torch for more information and other installation problems.

Usage

Suppose you have a trained neural network model and your model input data data. If you now want to interpret individual data points or the overall behavior using the methods from the package innsight, stick to the following pseudo code:

# --------------- Step 0: Train your model -----------------
# 'model' has to be an instance of either torch::nn_sequential, 
# keras::keras_model_sequential, keras::keras_model or neuralnet::neuralnet
model = ...

# -------------- Step 1: Convert your model ----------------
# For keras and neuralnet
converter <- convert(model)
# For a torch model the argument 'input_dim' is required
converter <- convert(model, input_dim = model_input_dim)

# -------------- Step 2: Apply method ----------------------
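# 'run_method()' is a placeholder for one of the method functions
# (assumed names from v0.3.0), e.g., run_lrp(), run_grad() or run_cw()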
# Apply global method
result <- run_method(converter) # no data argument is needed
# Apply local methods
result <- run_method(converter, data)

# -------------- Step 3: Get and plot results --------------
# Get the results as an array
res <- get_result(result)
# Plot individual results
plot(result)
# Plot an aggregated plot of all data points given in the argument 'data'
plot_global(result)
boxplot(result) # alias of `plot_global` for tabular and signal data
# Interactive plots can also be created for both plot functions
plot(result, as_plotly = TRUE)
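
As a concrete illustration, here is a minimal end-to-end sketch with a small neuralnet model on the iris data; the run_lrp() call, its rule_name/rule_param arguments and the chosen architecture are assumptions based on the innsight 0.3.0 interface and may need to be adapted:

library(innsight)
library(neuralnet)

# Step 0: Train a small network on two iris features (illustrative setup)
model <- neuralnet(Species ~ Petal.Length + Petal.Width, iris,
                   hidden = 3, linear.output = FALSE)

# Step 1: Convert the model (no 'input_dim' needed for neuralnet models)
converter <- convert(model)

# Step 2: Apply a local method, e.g., LRP with the epsilon rule
data <- iris[c(1, 51, 101), c("Petal.Length", "Petal.Width")]
result <- run_lrp(converter, data, rule_name = "epsilon", rule_param = 0.01)

# Step 3: Get the relevances as an array and plot individual results
get_result(result)
plot(result)

The same converter object can be reused with the other method functions.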

For a more detailed high-level introduction, see the introduction vignette, and for a full in-depth explanation with all the possibilities, see the “In-depth explanation” vignette.

Examples

  • Iris dataset with torch model (numeric tabular data) → vignette
  • Penguin dataset with torch model and trained with luz (numeric and categorical tabular data) → vignette
  • ImageNet dataset with pre-trained models in keras (image data) → article

Contributing and future work

If you would like to contribute, please open an issue or submit a pull request.

This package becomes even more alive and valuable if people use it for their analyses. Therefore, don’t hesitate to write me ([email protected]) or create a feature request if you are missing something for your analyses or have great ideas for extending this package. Currently, we are working on the following:

  • [ ] GPU support
  • [ ] More methods, e.g., Grad-CAM
  • [ ] More examples and documentation (contact me if you have a non-trivial application for me)

Funding

This work is funded by the German Research Foundation (DFG) in the context of the Emmy Noether Grant 437611051.

Metadata

Version

0.3.0

License

Unknown

Platforms (75)

  • aarch64-darwin
  • aarch64-genode
  • aarch64-linux
  • aarch64-netbsd
  • aarch64-none
  • aarch64_be-none
  • arm-none
  • armv5tel-linux
  • armv6l-linux
  • armv6l-netbsd
  • armv6l-none
  • armv7a-darwin
  • armv7a-linux
  • armv7a-netbsd
  • armv7l-linux
  • armv7l-netbsd
  • avr-none
  • i686-cygwin
  • i686-darwin
  • i686-freebsd
  • i686-genode
  • i686-linux
  • i686-netbsd
  • i686-none
  • i686-openbsd
  • i686-windows
  • javascript-ghcjs
  • loongarch64-linux
  • m68k-linux
  • m68k-netbsd
  • m68k-none
  • microblaze-linux
  • microblaze-none
  • microblazeel-linux
  • microblazeel-none
  • mips-linux
  • mips-none
  • mips64-linux
  • mips64-none
  • mips64el-linux
  • mipsel-linux
  • mipsel-netbsd
  • mmix-mmixware
  • msp430-none
  • or1k-none
  • powerpc-netbsd
  • powerpc-none
  • powerpc64-linux
  • powerpc64le-linux
  • powerpcle-none
  • riscv32-linux
  • riscv32-netbsd
  • riscv32-none
  • riscv64-linux
  • riscv64-netbsd
  • riscv64-none
  • rx-none
  • s390-linux
  • s390-none
  • s390x-linux
  • s390x-none
  • vc4-none
  • wasm32-wasi
  • wasm64-wasi
  • x86_64-cygwin
  • x86_64-darwin
  • x86_64-freebsd
  • x86_64-genode
  • x86_64-linux
  • x86_64-netbsd
  • x86_64-none
  • x86_64-openbsd
  • x86_64-redox
  • x86_64-solaris
  • x86_64-windows