The bartcs package performs confounder selection and treatment effect estimation with Bayesian Additive Regression Trees (BART).
This tutorial uses the Infant Health and Development Program (IHDP) dataset. The IHDP was a randomized experiment, conducted from 1985 to 1988, that studied the effect of home visits on cognitive test scores for infants. Our version is a synthetic dataset generated by Louizos et al. (2017), which provides true values to compare against. The dataset consists of 6 continuous and 19 binary covariates together with a simulated outcome, a cognitive test score.
data(ihdp, package = "bartcs")
fit <- single_bart(
  Y               = ihdp$y_factual, # observed outcome
  trt             = ihdp$treatment, # treatment indicator
  X               = ihdp[, 6:30],   # covariates
  num_tree        = 10,
  num_chain       = 4,
  num_post_sample = 100,
  num_burn_in     = 100,
  verbose         = FALSE
)
fit
#> `bartcs` fit by `single_bart()`
#>
#> mean 2.5% 97.5%
#> ATE 3.989180 3.757940 4.194852
#> Y1 6.414511 6.209326 6.594922
#> Y0 2.425331 2.352195 2.512516
You get the posterior mean and 95% credible interval of the average treatment effect (ATE) and the potential outcomes Y1 and Y0. Since the dataset is synthetic, we can compare these estimates with the true values.
# true values implied by the simulated potential outcome means
ATE <- mean(ihdp$mu1 - ihdp$mu0)
ATE
#> [1] 4.016067
mu1 <- mean(ihdp$mu1)
mu1
#> [1] 6.44858
mu0 <- mean(ihdp$mu0)
mu0
#> [1] 2.432513
# MSE between the posterior ATE samples and the true ATE
mse <- mean((unlist(fit$mcmc_list[, "ATE"]) - ATE)^2)
mse
#> [1] 0.01311197
The true values of ATE, mu1, and mu0 all lie within their 95% credible intervals. Also, the MSE between the posterior ATE samples and the true ATE is 0.013.
Both separate_bart() and single_bart() fit multiple MCMC chains. summary() provides results for each chain as well as for the aggregated chains.
summary(fit)
#> `bartcs` fit by `single_bart()`
#>
#> Treatment Value
#> Treated group : 1
#> Control group : 0
#>
#> Tree Parameters
#> Number of Tree : 10 Value of alpha : 0.95
#> Prob. of Grow : 0.28 Value of beta : 2
#> Prob. of Prune : 0.28 Value of nu : 3
#> Prob. of Change : 0.44 Value of q : 0.95
#>
#> Chain Parameters
#> Number of Chains : 4 Number of burn-in : 100
#> Number of Iter : 200 Number of thinning : 1
#> Number of Sample : 100
#>
#> Outcome
#> estimand chain 2.5% 1Q mean median 3Q 97.5%
#> ATE 1 3.773543 3.911107 3.972227 3.970111 4.056683 4.151239
#> ATE 2 3.772706 3.941568 3.999966 4.008911 4.055016 4.193598
#> ATE 3 3.741889 3.859122 3.943162 3.959975 4.021592 4.118731
#> ATE 4 3.819920 3.971550 4.041364 4.045832 4.111067 4.252098
#> ATE agg 3.757940 3.914043 3.989180 3.990748 4.061248 4.194852
#> Y1 1 6.215452 6.328724 6.396984 6.404919 6.460931 6.543268
#> Y1 2 6.261225 6.359134 6.426559 6.434194 6.486718 6.590550
#> Y1 3 6.184771 6.320276 6.377308 6.382648 6.450492 6.545511
#> Y1 4 6.245013 6.393345 6.457192 6.465348 6.513756 6.672324
#> Y1 agg 6.209326 6.347214 6.414511 6.423418 6.480329 6.594922
#> Y0 1 2.342513 2.399295 2.424757 2.423876 2.451448 2.506433
#> Y0 2 2.355629 2.398670 2.426593 2.426521 2.448976 2.495499
#> Y0 3 2.357711 2.401557 2.434146 2.431961 2.460607 2.517344
#> Y0 4 2.346422 2.389331 2.415827 2.412804 2.441819 2.511036
#> Y0 agg 2.352195 2.397022 2.425331 2.424030 2.450444 2.512516
You can get the posterior inclusion probability for each variable.
Since inclusion_plot() is a wrapper around ggcharts::bar_chart(), you can use its arguments to polish the plot, as in the sketch below.
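For example, a minimal sketch: top_n is an argument of ggcharts::bar_chart() that keeps only the bars with the highest values, and inclusion_plot() forwards extra arguments to it.
# Show the ten covariates with the highest posterior inclusion probability;
# top_n is passed through to ggcharts::bar_chart()
inclusion_plot(fit, top_n = 10)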
With trace_plot(), you can visually check the traces of treatment effects or other parameters; a sketch follows.
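A minimal sketch, assuming the second argument of trace_plot() names the quantity to trace:
# Trace of the ATE across posterior samples, one line per chain
trace_plot(fit, "ATE")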
You can also apply functions from the coda package to components of a bartcs object. Components with the mcmc_ prefix are mcmc.list objects from the coda package, as in the example below.
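For example, standard coda convergence diagnostics apply directly to fit$mcmc_list, the same component used above to compute the MSE:
library(coda)
# Gelman-Rubin potential scale reduction factor across the four chains
gelman.diag(fit$mcmc_list)
# Effective sample size of each monitored quantity (e.g. ATE)
effectiveSize(fit$mcmc_list)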
Check whether OpenMP is supported; you need more than one thread for multi-threading. Due to the overhead of multi-threading, parallelization will NOT be effective on small or moderately sized datasets.
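A quick check, assuming your installed bartcs exports the count_omp_thread() helper:
# Number of OpenMP threads available; 1 means multi-threading is unavailable
count_omp_thread()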
For comparison, I will create a dataset with 40,000 rows by bootstrapping from the IHDP dataset. Then, for fast computation, I will fit the model with minimal settings: one tree, one chain, ten posterior samples, and no burn-in.
# bootstrap the IHDP data up to 40,000 rows
idx  <- sample(nrow(ihdp), 4e4, replace = TRUE)
ihdp <- ihdp[idx, ]

microbenchmark::microbenchmark(
  simple = single_bart(
    Y               = ihdp$y_factual,
    trt             = ihdp$treatment,
    X               = ihdp[, 6:30],
    num_tree        = 1,
    num_chain       = 1,
    num_post_sample = 10,
    num_burn_in     = 0,
    verbose         = FALSE,
    parallel        = FALSE
  ),
  parallel = single_bart(
    Y               = ihdp$y_factual,
    trt             = ihdp$treatment,
    X               = ihdp[, 6:30],
    num_tree        = 1,
    num_chain       = 1,
    num_post_sample = 10,
    num_burn_in     = 0,
    verbose         = FALSE,
    parallel        = TRUE
  ),
  times = 50
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> simple 92.70185 104.79952 114.58517 111.7027 118.02281 255.8088 50
#> parallel 71.05865 80.39887 85.85849 83.7373 90.23573 127.7525 50
The results show that parallelization reduces computation time (mean 85.9 ms vs. 114.6 ms for the single-threaded fit).