The bartcs package selects confounders and estimates treatment effects with Bayesian Additive Regression Trees (BART).
This tutorial uses the Infant Health and Development Program (IHDP) dataset. IHDP was a randomized experiment, run from 1985 to 1988, that studied the effect of home visits on infants' cognitive test scores. Our version is a synthetic dataset generated by Louizos et al. (2017), which provides true values to compare against. It consists of 6 continuous and 19 binary covariates, with a simulated outcome representing a cognitive test score.
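library(bartcs)
data(ihdp, package = "bartcs")
# Fit a single BART model: 4 chains, each with 100 burn-in iterations
# followed by 100 retained posterior samples
fit <- single_bart(
  Y               = ihdp$y_factual,
  trt             = ihdp$treatment,
  X               = ihdp[, 6:30],
  num_tree        = 10,
  num_chain       = 4,
  num_post_sample = 100,
  num_burn_in     = 100,
  verbose         = FALSE
)
fit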
#> `bartcs` fit by `single_bart()`
#> mean 2.5% 97.5%
#> ATE 3.989180 3.757940 4.194852
#> Y1 6.414511 6.209326 6.594922
#> Y0 2.425331 2.352195 2.512516
The output reports the posterior mean and 95% credible interval of the average treatment effect (ATE) and of the potential outcomes Y1 and Y0.
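# True ATE and mean potential outcomes from the simulated ground truth
ATE <- mean(ihdp$mu1 - ihdp$mu0)
ATE
#> [1] 4.016067
mu1 <- mean(ihdp$mu1)
mu1
#> [1] 6.44858
mu0 <- mean(ihdp$mu0)
mu0
#> [1] 2.432513
# Mean squared error of the posterior ATE draws (all chains) against the true ATE
mse <- mean((unlist(fit$mcmc_list[, "ATE"]) - ATE)^2)
mse
#> [1] 0.01311197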
The true values of ATE, mu1 and mu0 all lie within their 95% credible intervals. The mean squared error (MSE) between the posterior ATE draws and the true ATE is about 0.013.
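The comparison above can be wrapped in a small helper. The following is a minimal sketch in base R, not part of bartcs: `check_against_truth()` is a hypothetical name, and the draws are simulated stand-ins so the block runs on its own. With a real fit you would pass `unlist(fit$mcmc_list[, "ATE"])` instead.

```r
# Hypothetical helper (not a bartcs function): compare posterior draws
# of an estimand with a known true value.
check_against_truth <- function(draws, truth, level = 0.95) {
  alpha <- 1 - level
  # Equal-tailed credible interval from the empirical quantiles
  ci <- quantile(draws, probs = c(alpha / 2, 1 - alpha / 2), names = FALSE)
  list(
    mean    = mean(draws),
    lower   = ci[1],
    upper   = ci[2],
    covered = truth >= ci[1] && truth <= ci[2],  # is the truth inside the interval?
    mse     = mean((draws - truth)^2)            # same formula as the MSE above
  )
}

set.seed(1)
# Simulated stand-in for unlist(fit$mcmc_list[, "ATE"]); values are illustrative only
toy_draws <- rnorm(400, mean = 4, sd = 0.1)
res <- check_against_truth(toy_draws, truth = 4.016067)
str(res)
```

Note that the MSE computed this way mixes posterior spread and bias, so it shrinks only when the draws are both concentrated and centered near the truth.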
Both separate_bart() and single_bart() fit multiple MCMC chains. summary() provides results for each chain and for the aggregated chain.
summary(fit)
#> `bartcs` fit by `single_bart()`
#> Treatment Value
#> Treated group : 1
#> Control group : 0
#> Tree Parameters
#> Number of Tree : 10 Value of alpha : 0.95
#> Prob. of Grow : 0.28 Value of beta : 2
#> Prob. of Prune : 0.28 Value of nu : 3
#> Prob. of Change : 0.44 Value of q : 0.95
#> Chain Parameters
#> Number of Chains : 4 Number of burn-in : 100
#> Number of Iter : 200 Number of thinning : 1
#> Number of Sample : 100
#> Outcome
#> estimand chain 2.5% 1Q mean median 3Q 97.5%
#> ATE 1 3.773543 3.911107 3.972227 3.970111 4.056683 4.151239
#> ATE 2 3.772706 3.941568 3.999966 4.008911 4.055016 4.193598
#> ATE 3 3.741889 3.859122 3.943162 3.959975 4.021592 4.118731
#> ATE 4 3.819920 3.971550 4.041364 4.045832 4.111067 4.252098
#> ATE agg 3.757940 3.914043 3.989180 3.990748 4.061248 4.194852
#> Y1 1 6.215452 6.328724 6.396984 6.404919 6.460931 6.543268
#> Y1 2 6.261225 6.359134 6.426559 6.434194 6.486718 6.590550
#> Y1 3 6.184771 6.320276 6.377308 6.382648 6.450492 6.545511
#> Y1 4 6.245013 6.393345 6.457192 6.465348 6.513756 6.672324
#> Y1 agg 6.209326 6.347214 6.414511 6.423418 6.480329 6.594922
#> Y0 1 2.342513 2.399295 2.424757 2.423876 2.451448 2.506433
#> Y0 2 2.355629 2.398670 2.426593 2.426521 2.448976 2.495499
#> Y0 3 2.357711 2.401557 2.434146 2.431961 2.460607 2.517344
#> Y0 4 2.346422 2.389331 2.415827 2.412804 2.441819 2.511036
#> Y0 agg 2.352195 2.397022 2.425331 2.424030 2.450444 2.512516
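To see how per-chain rows and an aggregated row can relate, here is a minimal base R sketch. It assumes the `agg` row is computed from the draws of all chains pooled together, which is consistent with the table (the aggregated ATE mean equals the average of the four chain means) but is not confirmed by the source. The chains are simulated and `summarise_draws()` is a hypothetical helper.

```r
set.seed(42)
# Four simulated chains of 100 draws each (illustrative values, not bartcs output)
chains <- lapply(1:4, function(i) rnorm(100, mean = 4, sd = 0.1))

# Hypothetical helper: 2.5% quantile, mean, and 97.5% quantile of a vector of draws
summarise_draws <- function(x) {
  c(quantile(x, 0.025), mean = mean(x), quantile(x, 0.975))
}

# One row per chain
per_chain <- t(sapply(chains, summarise_draws))
rownames(per_chain) <- paste("chain", 1:4)

# Aggregated row: pool all draws, then summarise
aggregated <- summarise_draws(unlist(chains))

rbind(per_chain, agg = aggregated)
```

With equal-length chains the pooled mean is exactly the average of the chain means, while the pooled quantiles are generally not the average of the chain quantiles.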
You can get the posterior inclusion probability of each variable.
Since inclusion_plot() is a wrapper around another plotting function, you can pass that function's arguments to customize the plot.
With trace_plot(), you can visually check the trace of effects or other parameters.
You can also use functions from the coda package on components of a bartcs object. Components with the mcmc_ prefix (for example, fit$mcmc_list used above) are mcmc.list objects from coda.
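As an illustration of what coda offers for such objects, the sketch below builds a toy `mcmc.list` from simulated draws and runs two standard diagnostics. It assumes the coda package is installed; the column names and values are made up for the example. With a real fit you would pass `fit$mcmc_list` in place of `toy`.

```r
library(coda)

set.seed(7)
# Toy mcmc.list: 4 chains, 100 draws each, two monitored quantities
# (simulated values standing in for a bartcs fit)
toy <- mcmc.list(lapply(1:4, function(i) {
  mcmc(cbind(ATE = rnorm(100, mean = 4,   sd = 0.1),
             Y1  = rnorm(100, mean = 6.4, sd = 0.1)))
}))

# Gelman-Rubin potential scale reduction factor; values near 1 suggest
# the chains have mixed
psrf <- gelman.diag(toy)
psrf

# Effective sample size, summed over chains
ess <- effectiveSize(toy)
ess
```

Other coda functions such as `traceplot()` and `summary()` accept the same kind of object.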
First check whether OpenMP is supported. You need more than 1 thread for multi-threading. Because multi-threading adds overhead, parallelization will NOT be effective on small and moderate-sized datasets.
For comparison, I will create a dataset with 40,000 rows by bootstrapping from the IHDP dataset. Then, to keep computation fast, I will fit the model with most parameters set to 1.
# Bootstrap 40,000 rows (sampling with replacement) from IHDP
idx  <- sample(nrow(ihdp), 4e4, TRUE)
ihdp <- ihdp[idx, ]
# Time the same minimal fit with and without multi-threading, 50 runs each
microbenchmark::microbenchmark(
  simple = single_bart(
    Y               = ihdp$y_factual,
    trt             = ihdp$treatment,
    X               = ihdp[, 6:30],
    num_tree        = 1,
    num_chain       = 1,
    num_post_sample = 10,
    num_burn_in     = 0,
    verbose         = FALSE,
    parallel        = FALSE
  ),
  parallel = single_bart(
    Y               = ihdp$y_factual,
    trt             = ihdp$treatment,
    X               = ihdp[, 6:30],
    num_tree        = 1,
    num_chain       = 1,
    num_post_sample = 10,
    num_burn_in     = 0,
    verbose         = FALSE,
    parallel        = TRUE
  ),
  times = 50
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> simple 92.06316 102.56569 113.35694 110.76865 117.82958 283.0242 50
#> parallel 72.70847 80.07966 85.21757 84.51305 87.92134 107.7117 50
The results show that parallelization reduces computation time on this dataset: the median run time drops from about 111 ms to about 85 ms.
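The size of the gain can be read off the benchmark table directly. The short sketch below only redoes that arithmetic, with the mean and median timings copied from the table above; the variable names are illustrative.

```r
# Timings in milliseconds, copied from the benchmark output above
simple_ms   <- c(mean = 113.35694, median = 110.76865)
parallel_ms <- c(mean = 85.21757,  median = 84.51305)

# Speedup factor: how many times faster the parallel run is
speedup <- simple_ms / parallel_ms
round(speedup, 2)

# Relative reduction in run time, as a percentage
reduction_pct <- 100 * (1 - parallel_ms / simple_ms)
round(reduction_pct, 1)
```

These figures come from a single machine and a deliberately minimal model, so the speedup on other hardware or with larger settings may differ.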