Description

Outcome Weights of Treatment Effect Estimators.

Many treatment effect estimators can be written as weighted outcomes. These weights have established uses, such as checking covariate balance with packages like 'cobalt'. This package takes the original estimator objects and outputs their outcome weights. It builds on the general framework of Knaus (2024) <doi:10.48550/arXiv.2411.11559>. This version is compatible with the 'grf' package and provides an internal implementation of Double Machine Learning.
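As a minimal illustration of the balance idea, the base-R sketch below uses the simplest possible estimator, the difference in means, rather than anything computed by this package or by 'cobalt'. Its outcome weights are assumed to be 1/n1 for treated and -1/n0 for control units. Writing the estimate as ω'Y and then applying the same weights to the covariates gives ω'X, the covariate mean difference the estimator leaves in place. The data-generating process is made up for illustration.

```r
# Minimal sketch (base R only): applying outcome weights to covariates.
# Assumption: the estimator is the plain difference in means, whose outcome
# weights are 1/n1 for treated units and -1/n0 for control units.
set.seed(42)
n <- 200
X <- matrix(rnorm(n * 3), n, 3, dimnames = list(NULL, c("x1", "x2", "x3")))
W <- rbinom(n, 1, plogis(X[, 1]))          # treatment depends on x1 (confounded)
Y <- X[, 1] + W + rnorm(n)

n1 <- sum(W)
n0 <- n - n1
omega <- ifelse(W == 1, 1 / n1, -1 / n0)   # outcome weights of the difference in means

# The estimate written as a weighted outcome: omega' Y
tau_hat <- sum(omega * Y)
all.equal(tau_hat, mean(Y[W == 1]) - mean(Y[W == 0]))

# Applying the same weights to each covariate gives the implied imbalance: omega' X
imbalance <- drop(crossprod(omega, X))
imbalance
all.equal(imbalance, colMeans(X[W == 1, ]) - colMeans(X[W == 0, ]))
```

Because treatment depends on x1 in this toy example, the x1 entry of the imbalance is typically clearly non-zero, which is the kind of signal a balance check is meant to surface.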

Outcome Weights


This R package calculates the outcome weights of Knaus (2024). Its use is illustrated in the average effects and heterogeneous effects R notebooks that accompany the paper as supplementary material.

The core functionality is the get_outcome_weights() method, which implements the theoretical result in Proposition 1 of the paper. Proposition 1 shows that the outcome weight vector can be written in the general form $\boldsymbol{\omega}' = (\tilde{\boldsymbol{Z}}'\tilde{\boldsymbol{D}})^{-1} \tilde{\boldsymbol{Z}}'\boldsymbol{T}$, where $\tilde{\boldsymbol{Z}}$, $\tilde{\boldsymbol{D}}$ and $\boldsymbol{T}$ are the pseudo-instrument, the pseudo-treatment and the transformation matrix, respectively. The estimate is then the weighted outcome $\hat{\tau} = \boldsymbol{\omega}'\boldsymbol{Y}$.
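To make the formula concrete, here is a minimal base-R sketch for a partially linear regression (PLR). It is not the package's implementation. It assumes both nuisance regressions are OLS fits on the covariates, so the outcome smoother S is the OLS hat matrix (Ŷ = SY), and it takes Z̃ = D̃ = D - SD and T = I - S. The check is that ω'Y reproduces the residual-on-residual regression and, by the Frisch-Waugh-Lovell theorem, the coefficient on D in a regression of Y on D and X.

```r
# Minimal sketch (base R only) of omega' = (Z~'D~)^{-1} Z~'T for a PLR.
# Assumptions for illustration: linear OLS smoothers stand in for the
# package's random forests; Z~ = D~ = residualized treatment; T = I - S.
set.seed(1)
n <- 300
X <- cbind(1, matrix(rnorm(n * 2), n, 2))     # covariates incl. intercept
D <- drop(X %*% c(0, 0.5, -0.3)) + rnorm(n)   # continuous treatment
Y <- 2 * D + drop(X %*% c(1, 1, 1)) + rnorm(n)

S <- X %*% solve(crossprod(X), t(X))          # smoother matrix: Y.hat = S %*% Y
D_tilde <- D - drop(S %*% D)                  # pseudo-treatment (residualized D)
Z_tilde <- D_tilde                            # pseudo-instrument (PLR case)
T_mat <- diag(n) - S                          # transformation matrix (avoids masking T)

# Outcome weights as a 1 x n row vector
omega <- solve(crossprod(Z_tilde, D_tilde), crossprod(Z_tilde, T_mat))

# The weighted outcome reproduces the residual-on-residual estimate ...
tau_weights <- drop(omega %*% Y)
Y_tilde <- Y - drop(S %*% Y)
tau_resid <- unname(coef(lm(Y_tilde ~ D_tilde - 1)))
all.equal(tau_weights, tau_resid)

# ... and the coefficient on D when regressing Y on D and X
all.equal(tau_weights, unname(coef(lm(Y ~ D + X[, -1]))["D"]))
```

With flexible learners such as regression forests, S is no longer a projection, but the same algebra applies as long as the fitted values are linear in Y, which is why the package needs access to smoother matrices.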

The goal is to make it compatible with as many estimated R objects as possible.

The package can be installed from CRAN:

install.packages("OutcomeWeights")

The package is a work in progress. Its current state is listed below (suggestions welcome):

In progress

  • [ ] Compatibility with the grf package
    • [x] causal_forest() outcome weights for CATE
    • [x] instrumental_forest() outcome weights for CLATE
    • [x] causal_forest() outcome weights for ATE from average_treatment_effect()
    • [ ] All outcome weights for average parameters compatible with average_treatment_effect()
  • [ ] Package-internal Double ML implementation that handles the required outcome smoother matrices
    • [x] Nuisance parameter estimation based on honest random forests (regression_forest() of the grf package)
    • [x] dml_with_smoother() runs for PLR, PLR-IV, AIPW-ATE, and Wald_AIPW and is compatible with get_outcome_weights()
    • [ ] Add more Double ML estimators
    • [ ] Add support for more smoothers
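The PLR-IV estimator listed above fits the same general form $\boldsymbol{\omega}' = (\tilde{\boldsymbol{Z}}'\tilde{\boldsymbol{D}})^{-1} \tilde{\boldsymbol{Z}}'\boldsymbol{T}$, now with a genuine pseudo-instrument. The base-R sketch below is not dml_with_smoother(); it assumes linear OLS smoothers for all nuisance fits, Z̃ = Z - SZ, D̃ = D - SD and T = I - S, and cross-checks ω'Y against a just-identified two-stage least squares fit with the covariates as exogenous controls. The data-generating process is hypothetical.

```r
# Minimal sketch (base R only): outcome weights for a PLR-IV estimator.
# Assumptions for illustration: linear OLS smoothers for all nuisance fits,
# Z~ = Z - S Z, D~ = D - S D, and T = I - S.
set.seed(2)
n <- 300
X <- cbind(1, matrix(rnorm(n * 2), n, 2))    # covariates incl. intercept
Z <- rbinom(n, 1, 0.5)                       # binary instrument
U <- rnorm(n)                                # unobserved confounder
D <- 0.8 * Z + X[, 2] + U + rnorm(n)         # endogenous treatment
Y <- 1.5 * D + X[, 3] + U + rnorm(n)

S <- X %*% solve(crossprod(X), t(X))         # smoother matrix
Z_tilde <- Z - drop(S %*% Z)                 # pseudo-instrument
D_tilde <- D - drop(S %*% D)                 # pseudo-treatment
T_mat <- diag(n) - S                         # transformation matrix

omega <- solve(crossprod(Z_tilde, D_tilde), crossprod(Z_tilde, T_mat))
tau_weights <- drop(omega %*% Y)

# Cross-check: just-identified 2SLS with X in both regressors and instruments
R <- cbind(X, D)                             # regressors
Q <- cbind(X, Z)                             # instruments
beta_2sls <- solve(crossprod(Q, R), crossprod(Q, Y))
all.equal(tau_weights, unname(beta_2sls[ncol(R), 1]))
```

The agreement follows from partialling the exogenous covariates out of instrument, treatment and outcome before forming the IV ratio.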

Envisioned features

  • [ ] Compatibility with DoubleML (a non-trivial task because the mlr3 environment it builds on does not provide smoother matrices)
    • [ ] Extract the smoother matrices available in mlr3, where possible
    • [ ] Make the smoother matrices of mlr3 accessible within DoubleML
    • [ ] Write a get_outcome_weights() method for DoubleML estimators
  • [ ] Collect packages from which weights could be extracted and implement them

The following code creates synthetic data to show how causal forest outcome weights are extracted and to verify that they replicate the original estimates:

if (!require("OutcomeWeights")) install.packages("OutcomeWeights", dependencies = TRUE)
library(OutcomeWeights)

# Sample from DGP borrowed from grf documentation
n = 500
p = 10
X = matrix(rnorm(n * p), n, p)
W = rbinom(n, 1, 0.5)
Y = pmax(X[, 1], 0) * W + X[, 2] + pmin(X[, 3], 0) + rnorm(n)

# Run outcome regression and extract smoother matrix
forest.Y = grf::regression_forest(X, Y)
Y.hat = predict(forest.Y)$predictions
outcome_smoother = grf::get_forest_weights(forest.Y)

# Run causal forest with external Y.hats
c.forest = grf::causal_forest(X, Y, W, Y.hat = Y.hat)

# Predict on out-of-bag training samples.
cate.oob = predict(c.forest)$predictions

# Predict using the forest.
X.test = matrix(0, 101, p)
X.test[, 1] = seq(-2, 2, length.out = 101)
cate.test = predict(c.forest, X.test)$predictions

# Calculate outcome weights
omega_oob = get_outcome_weights(c.forest, S = outcome_smoother)
omega_test = get_outcome_weights(c.forest, S = outcome_smoother, newdata = X.test)

# Observe that they perfectly replicate the original CATEs
all.equal(as.numeric(omega_oob$omega %*% Y), 
          as.numeric(cate.oob))
all.equal(as.numeric(omega_test$omega %*% Y), 
          as.numeric(cate.test))

# Also the ATE estimates are perfectly replicated
omega_ate = get_outcome_weights(c.forest, target = "ATE", S = outcome_smoother, S.tau = omega_oob$omega)
all.equal(as.numeric(omega_ate$omega %*% Y),
          as.numeric(grf::average_treatment_effect(c.forest, target.sample = "all")[1]))

The development version can be installed from GitHub using the devtools package:

library(devtools)
install_github(repo="MCKnaus/OutcomeWeights")

References

Knaus, M. C. (2024). Treatment effect estimators as weighted outcomes. arXiv:2411.11559.

Metadata

Version

0.1.1

License

Unknown

Platforms (77)

    Darwin
    FreeBSD
    Genode
    GHCJS
    Linux
    MMIXware
    NetBSD
    none
    OpenBSD
    Redox
    Solaris
    WASI
    Windows
  • aarch64-darwin
  • aarch64-freebsd
  • aarch64-genode
  • aarch64-linux
  • aarch64-netbsd
  • aarch64-none
  • aarch64-windows
  • aarch64_be-none
  • arm-none
  • armv5tel-linux
  • armv6l-linux
  • armv6l-netbsd
  • armv6l-none
  • armv7a-darwin
  • armv7a-linux
  • armv7a-netbsd
  • armv7l-linux
  • armv7l-netbsd
  • avr-none
  • i686-cygwin
  • i686-darwin
  • i686-freebsd
  • i686-genode
  • i686-linux
  • i686-netbsd
  • i686-none
  • i686-openbsd
  • i686-windows
  • javascript-ghcjs
  • loongarch64-linux
  • m68k-linux
  • m68k-netbsd
  • m68k-none
  • microblaze-linux
  • microblaze-none
  • microblazeel-linux
  • microblazeel-none
  • mips-linux
  • mips-none
  • mips64-linux
  • mips64-none
  • mips64el-linux
  • mipsel-linux
  • mipsel-netbsd
  • mmix-mmixware
  • msp430-none
  • or1k-none
  • powerpc-netbsd
  • powerpc-none
  • powerpc64-linux
  • powerpc64le-linux
  • powerpcle-none
  • riscv32-linux
  • riscv32-netbsd
  • riscv32-none
  • riscv64-linux
  • riscv64-netbsd
  • riscv64-none
  • rx-none
  • s390-linux
  • s390-none
  • s390x-linux
  • s390x-none
  • vc4-none
  • wasm32-wasi
  • wasm64-wasi
  • x86_64-cygwin
  • x86_64-darwin
  • x86_64-freebsd
  • x86_64-genode
  • x86_64-linux
  • x86_64-netbsd
  • x86_64-none
  • x86_64-openbsd
  • x86_64-redox
  • x86_64-solaris
  • x86_64-windows