Description

Generative Adversarial Nets (GAN) in R.

An easy way to get started with Generative Adversarial Nets (GAN) in R. The GAN algorithm was first described by Goodfellow et al. (2014) <https://proceedings.neurips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf>. A GAN can be used to learn the joint distribution of complex data by comparison. A GAN consists of two neural networks, a Generator and a Discriminator, which play an adversarial minimax game against each other. Built-in GAN models make it possible to train a GAN in R in a single line of code and make it easy to experiment with different design choices (e.g., network architectures, value functions, optimizers). The built-in GAN models work with tabular data (e.g., to produce synthetic data) and with image data. Methods to post-process the output of GAN models to enhance the quality of samples are also available.
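In the notation of Goodfellow et al. (2014), the Generator G and the Discriminator D play the minimax game

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

where the Discriminator learns to tell real samples x from generated samples G(z), and the Generator learns to fool it.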

RGAN

The goal of RGAN is to facilitate training and experimentation with Generative Adversarial Nets (GAN) in R.

Installation

You can install the released version of RGAN from CRAN with:

install.packages("RGAN")

And the development version from GitHub with:

# install.packages("devtools")
devtools::install_github("mneunhoe/RGAN")

Example

This is a basic example that shows how to train a GAN on toy data and observe the training progress.

Before running RGAN for the first time, you need to make sure that torch is properly installed:

install.packages("torch")
library(torch)
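
If the torch package is installed but its backend libraries (libtorch) have not been downloaded yet, you can trigger the download manually. A minimal check, using the installation helpers that ship with the torch package:

# Download the libtorch backend if it is missing.
if (!torch::torch_is_installed()) {
  torch::install_torch()
}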

Then you can start training a GAN on toy data (or, potentially, your own data).

library(RGAN)

# Sample some toy data to play with.
data <- sample_toydata()

# Transform (here standardize) the data to facilitate learning.
# First, create a new data transformer.
transformer <- data_transformer$new()

# Fit the transformer to your data.
transformer$fit(data)

# Use the fitted transformer to transform your data.
transformed_data <- transformer$transform(data)
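
# Quick sanity check (transform() returns a numeric matrix here):
# standardized columns should have mean ~0 and standard deviation ~1.
round(colMeans(transformed_data), 3)
round(apply(transformed_data, 2, sd), 3)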

# Have a look at the transformed data.
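# A 3 x 2 grid is set up here, presumably so that the intermediate plots
# produced during training below fill the remaining panels.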
par(mfrow = c(3, 2))
plot(
  transformed_data,
  bty = "n",
  col = viridis::viridis(2, alpha = 0.7)[1],
  pch = 19,
  xlab = "Var 1",
  ylab = "Var 2",
  main = "The Real Data",
  las = 1
)

# Set the device you want to train on.
# First, we check whether a compatible GPU is available for computation.
use_cuda <- torch::cuda_is_available()

# If so, we use it to speed up training (especially for models on image data).
device <- if (use_cuda) "cuda" else "cpu"

# Now train the GAN and observe some intermediate results.
res <-
  gan_trainer(
    transformed_data,
    eval_dropout = TRUE,
    plot_progress = TRUE,
    plot_interval = 600,
    device = device
  )
#> Training the GAN ■■                                 3% | ETA:  1m
#> Training the GAN ■■■■■■■■■■■■■■■■■                 53% | ETA: 27s
#> Training the GAN ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■  100% | ETA:  0s

After training, you can work with the resulting GAN to sample synthetic data, or continue training for further steps.
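The object returned by gan_trainer bundles the trained networks together with their optimizers; these components are reused below when sampling and when continuing training. A quick way to inspect what is available:

str(res, max.level = 1)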

If you want to sample synthetic data from your GAN, you need to provide a GAN Generator and a noise vector (which needs to be a torch tensor and should come from the same distribution that you used during training). For example, we can compare synthetic data generated with and without dropout during generation/inference (using the same noise vector).

par(mfrow = c(1, 2))

# Set the noise vector.
noise_vector <- torch::torch_randn(c(nrow(transformed_data), 2))$to(device = device)

# Generate synthetic data from the trained generator with dropout during generation.
synth_data_dropout <- expert_sample_synthetic_data(res$generator,
                                                   noise_vector,
                                                   eval_dropout = TRUE)

# Plot the real data and the synthetic data.
GAN_update_plot(data = transformed_data,
                synth_data = synth_data_dropout,
                main = "With dropout")

# Repeat without dropout at generation time.
synth_data_no_dropout <- expert_sample_synthetic_data(res$generator,
                                                      noise_vector,
                                                      eval_dropout = FALSE)

GAN_update_plot(data = transformed_data,
                synth_data = synth_data_no_dropout,
                main = "Without dropout")
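
Since the GAN was trained on standardized data, the synthetic samples live on that scale as well. Assuming the fitted data_transformer exposes an inverse_transform() method that undoes the standardization (mirroring fit() and transform() above), you can map the samples back to the original scale. A minimal sketch:

# Map the synthetic samples back to the scale of the original data.
# inverse_transform() is assumed to undo the standardization applied by transform().
synth_data_original_scale <- transformer$inverse_transform(synth_data_no_dropout)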

If you want to continue training, you can pass the generator and discriminator, as well as the respective optimizers, to gan_trainer like this:

res_cont <- gan_trainer(
  transformed_data,
  generator = res$generator,
  discriminator = res$discriminator,
  generator_optimizer = res$generator_optimizer,
  discriminator_optimizer = res$discriminator_optimizer,
  epochs = 10
)
#> Training the GAN ■■■■■■■■■■■■■■■■                  50% | ETA:  2s
#> Training the GAN ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■  100% | ETA:  0s
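
To reuse a trained GAN in a later session, you can persist the networks with torch's serialization helpers. A minimal sketch (torch_save() and torch_load() are part of the torch package; the file name is just an example):

# Save the trained generator to disk and reload it later.
torch::torch_save(res_cont$generator, "rgan_generator.pt")
generator <- torch::torch_load("rgan_generator.pt")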
Metadata

Version

0.1.1

License

Unknown

Platforms (75)

    Darwin
    FreeBSD
    Genode
    GHCJS
    Linux
    MMIXware
    NetBSD
    none
    OpenBSD
    Redox
    Solaris
    WASI
    Windows
  • aarch64-darwin
  • aarch64-genode
  • aarch64-linux
  • aarch64-netbsd
  • aarch64-none
  • aarch64_be-none
  • arm-none
  • armv5tel-linux
  • armv6l-linux
  • armv6l-netbsd
  • armv6l-none
  • armv7a-darwin
  • armv7a-linux
  • armv7a-netbsd
  • armv7l-linux
  • armv7l-netbsd
  • avr-none
  • i686-cygwin
  • i686-darwin
  • i686-freebsd
  • i686-genode
  • i686-linux
  • i686-netbsd
  • i686-none
  • i686-openbsd
  • i686-windows
  • javascript-ghcjs
  • loongarch64-linux
  • m68k-linux
  • m68k-netbsd
  • m68k-none
  • microblaze-linux
  • microblaze-none
  • microblazeel-linux
  • microblazeel-none
  • mips-linux
  • mips-none
  • mips64-linux
  • mips64-none
  • mips64el-linux
  • mipsel-linux
  • mipsel-netbsd
  • mmix-mmixware
  • msp430-none
  • or1k-none
  • powerpc-netbsd
  • powerpc-none
  • powerpc64-linux
  • powerpc64le-linux
  • powerpcle-none
  • riscv32-linux
  • riscv32-netbsd
  • riscv32-none
  • riscv64-linux
  • riscv64-netbsd
  • riscv64-none
  • rx-none
  • s390-linux
  • s390-none
  • s390x-linux
  • s390x-none
  • vc4-none
  • wasm32-wasi
  • wasm64-wasi
  • x86_64-cygwin
  • x86_64-darwin
  • x86_64-freebsd
  • x86_64-genode
  • x86_64-linux
  • x86_64-netbsd
  • x86_64-none
  • x86_64-openbsd
  • x86_64-redox
  • x86_64-solaris
  • x86_64-windows