---
title: "checkpointing: brms"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{checkpointing: brms}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

The following examples walk through using **chkptstanr** with the popular `R` package [**brms**](https://paul-buerkner.github.io/brms/). The basic idea is to (1) generate the [**Stan**](https://mc-stan.org/users/interfaces/rstan) code with **brms**, (2) fit the model with [**cmdstanr**](https://mc-stan.org/cmdstanr/) (with the desired number of checkpoints), and then (3) return a `brmsfit` object. This is all done internally, so the workflow is very similar to using **brms**.

## Packages

```r
library(chkptstanr)
library(posterior)
library(bayesplot)
library(ggplot2)
library(brms)
```

# Example 1: No Stopping

## Storage
The initial overhead is to create a folder that will store the checkpoints, i.e.,

```r
path <- create_folder(folder_name = "chkpt_folder_m1")
```

which contains several additional folders (details can be found in the documentation).

## `brmsformula`
In this example, we create a `brmsformula` object using `bf()`. Note that for this model, we could also use the `formula` argument directly (e.g., `formula = y ~ x`), but in our experience `bf()` is more general.

```r
bf_m1 <- bf(formula = count ~ zAge + zBase + (1 | patient),
            family = poisson())
```

## Model Fitting
The next step is to use `chkpt_brms()`:

```r
fit_m1 <- chkpt_brms(
  formula = bf_m1,
  data = epilepsy,
  path = path,
  iter_warmup = 1000,
  iter_sampling = 1000,
  iter_per_chkpt = 250
)
```

When running the above, a custom progress bar is printed that includes information about the checkpoints.

```r
#> Compiling Stan program...
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 8; Iteration: 250 / 2000 (warmup)
#> Chkpt: 2 / 8; Iteration: 500 / 2000 (warmup)
#> Chkpt: 3 / 8; Iteration: 750 / 2000 (warmup)
#> Chkpt: 4 / 8; Iteration: 1000 / 2000 (warmup)
#> Chkpt: 5 / 8; Iteration: 1250 / 2000 (sample)
#> Chkpt: 6 / 8; Iteration: 1500 / 2000 (sample)
#> Chkpt: 7 / 8; Iteration: 1750 / 2000 (sample)
#> Chkpt: 8 / 8; Iteration: 2000 / 2000 (sample)
#> Checkpointing complete
```

In this case, checkpointing is complete.

## Summary
`fit_m1` is a `brmsfit` object, which means that all of the functionality of **brms** can still be used. Here is the summary output:

```r
fit_m1

#>  Family: poisson
#>   Links: mu = log
#> Formula: count ~ zAge + zBase + (1 | patient)
#>    Data: data (Number of observations: 236)
#>   Draws: 2 chains, each with iter = 1000; warmup = 0; thin = 1;
#>          total post-warmup draws = 2000
#>
#> Group-Level Effects:
#> ~patient (Number of levels: 59)
#>               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> sd(Intercept)     0.58      0.07     0.46     0.73 1.00      349      682
#> Population-Level Effects:
#>           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> Intercept     1.63      0.08     1.46     1.78 1.01      406      898
#> zAge          0.11      0.09    -0.06     0.27 1.00      463      796
#> zBase         0.73      0.08     0.58     0.89 1.00      613      814
#>
#> Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).
```

## Posterior Predictive Check
Because `fit_m1` is a `brmsfit` object, a posterior predictive check works exactly as it does in **brms**.

```r
pp_check(fit_m1)
```

![](../man/figures/pp_check_f1.png)

# Example 2: Start, Stop, Start, etc.
The previous example could just as well have been fitted directly with **brms**, because the MCMC sampler was never stopped during model fitting.
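
The checkpoint counts in the progress output follow from simple arithmetic: `iter_warmup + iter_sampling` iterations are split into blocks of `iter_per_chkpt`, so 2000 / 250 gives 8 checkpoints, the first 4 of which fall in warmup. A minimal base-`R` sketch of that schedule is shown here. Note that `checkpoint_schedule()` is a hypothetical helper written for illustration (it is not part of **chkptstanr**), and it assumes that both phases divide evenly by `iter_per_chkpt`.

```r
# Hypothetical helper (NOT part of chkptstanr): tabulates which iteration
# each checkpoint ends on and which phase it belongs to.
checkpoint_schedule <- function(iter_warmup, iter_sampling, iter_per_chkpt) {
  # Assumption of this sketch: both phases split evenly into checkpoints.
  stopifnot(
    iter_warmup %% iter_per_chkpt == 0,
    iter_sampling %% iter_per_chkpt == 0
  )
  total <- iter_warmup + iter_sampling
  iteration <- seq(iter_per_chkpt, total, by = iter_per_chkpt)
  data.frame(
    chkpt = seq_along(iteration),
    iteration = iteration,
    phase = ifelse(iteration <= iter_warmup, "warmup", "sample"),
    stringsAsFactors = FALSE
  )
}

# Same settings as the chkpt_brms() calls in this vignette
schedule <- checkpoint_schedule(
  iter_warmup = 1000,
  iter_sampling = 1000,
  iter_per_chkpt = 250
)

nrow(schedule)  # 8 checkpoints in total
schedule        # checkpoints 1-4 are warmup, 5-8 are sample
```

The same table shows how much work remains after an interruption: stopping after checkpoint 2, for instance, leaves 2 warmup and 4 sampling checkpoints.
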
In the following example, we illustrate the usefulness of **chkptstanr**, i.e., the ability to stop the MCMC sampler at will and then pick right back up where it left off.

## Storage
The initial overhead is to create a folder that will store the checkpoints, i.e.,

```r
path <- create_folder(folder_name = "chkpt_folder_m2")
```

## Model Fitting
This model is mostly the same as above. The one difference is that it does not include varying ("random") intercepts.

### Start and Stop: Two Checkpoints
To illustrate checkpointing, the following run was stopped after 2 checkpoints.

```r
fit_m2 <- chkpt_brms(
  bf(formula = count ~ zAge + zBase,
     family = poisson()),
  data = epilepsy,
  path = path,
  iter_warmup = 1000,
  iter_sampling = 1000,
  iter_per_chkpt = 250
)

#> Compiling Stan program...
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 8; Iteration: 250 / 2000 (warmup)
#> Chkpt: 2 / 8; Iteration: 500 / 2000 (warmup)
```

Note that this was stopped by clicking on the red button aptly titled stop (in the console). This is but one use case: for example, needing to do something else without losing the progress made so far (including the compiled model). Another use case is scheduling, such that the model samples only during certain times until completion.

### Start and Stop: Two More Checkpoints
Now pick up at the next checkpoint. This is accomplished by simply running the same code.

```r
fit_m2 <- chkpt_brms(
  formula = bf(formula = count ~ zAge + zBase,
               family = poisson()),
  data = epilepsy,
  path = path,
  iter_warmup = 1000,
  iter_sampling = 1000,
  iter_per_chkpt = 250
)

#> Sampling next checkpoint
#> Chkpt: 3 / 8; Iteration: 750 / 2000 (warmup)
#> Chkpt: 4 / 8; Iteration: 1000 / 2000 (warmup)
```

Notice that it picks up right where it left off (after 2 checkpoints).

### Start: Finish Checkpointing
Now let us finish the remaining 4 checkpoints.
```r
fit_m2 <- chkpt_brms(
  formula = bf(formula = count ~ zAge + zBase,
               family = poisson()),
  data = epilepsy,
  path = path,
  iter_warmup = 1000,
  iter_sampling = 1000,
  iter_per_chkpt = 250
)

#> Sampling next checkpoint
#> Chkpt: 5 / 8; Iteration: 1250 / 2000 (sample)
#> Chkpt: 6 / 8; Iteration: 1500 / 2000 (sample)
#> Chkpt: 7 / 8; Iteration: 1750 / 2000 (sample)
#> Chkpt: 8 / 8; Iteration: 2000 / 2000 (sample)
#> Checkpointing complete
```

If we try running the model again, we get the following message:

```r
fit_m2 <- chkpt_brms(
  formula = bf(formula = count ~ zAge + zBase,
               family = poisson()),
  data = epilepsy,
  path = path,
  iter_warmup = 1000,
  iter_sampling = 1000,
  iter_per_chkpt = 250
)

#> Sampling next checkpoint
#> Checkpointing complete
```

Note that the arguments need to be exactly the same when restarting. There is a check for `data`, `formula`, `iter_per_chkpt`, etc., and if any of them have changed, an error is raised with an informative message.

## Diagnostics
Some diagnostic information is provided in the summary output.

```r
fit_m2

#>  Family: poisson
#>   Links: mu = log
#> Formula: count ~ zAge + zBase
#>    Data: data (Number of observations: 236)
#>   Draws: 2 chains, each with iter = 1000; warmup = 0; thin = 1;
#>          total post-warmup draws = 2000
#>
#> Population-Level Effects:
#>           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> Intercept     1.84      0.03     1.78     1.89 1.00     1037     1009
#> zAge          0.16      0.02     0.11     0.21 1.00     1192      945
#> zBase         0.60      0.01     0.58     0.63 1.00     1463     1559
#>
#> Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).
```

These diagnostics (an Rhat of 1.00 for every parameter) indicate that the chains converged.

## More Diagnostics
**cmdstanr** works with several packages in the **Stan** ecosystem, including [**posterior**](https://mc-stan.org/posterior/) and [**bayesplot**](https://mc-stan.org/bayesplot/).
```r
# draws for bayesplot
draws <- posterior::as_draws_array(fit_m2)

# trace plot
bayesplot::mcmc_trace(x = draws, pars = "b_zAge") +
  geom_vline(xintercept = seq(0, 1000, 250),
             alpha = 0.25,
             size = 2)
```

![](../man/figures/trace_f2.png)

The vertical lines are placed at each checkpoint.

## Model Comparison
These models can then be compared with approximate leave-one-out cross-validation (via the `R` package [**loo**](http://mc-stan.org/loo/index.html)).

```r
loo_compare(loo(fit_m1), loo(fit_m2))

#>        elpd_diff se_diff
#> fit_m1    0.0       0.0
#> fit_m2 -203.6      65.4
```

# Compare to `brm`
As a sanity check, here is the same model as `fit_m2` fitted directly with **brms**. The estimates should be (basically) the same.

```r
fit_brms <- brm(
  formula = bf(formula = count ~ zAge + zBase,
               family = poisson()),
  data = epilepsy,
  chains = 2,
  iter = 2000
)

fit_brms

#>  Family: poisson
#>   Links: mu = log
#> Formula: count ~ zAge + zBase
#>    Data: epilepsy (Number of observations: 236)
#>   Draws: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
#>          total post-warmup draws = 2000
#>
#> Population-Level Effects:
#>           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> Intercept     1.84      0.03     1.78     1.89 1.00     1247     1310
#> zAge          0.16      0.02     0.11     0.21 1.00     1226     1191
#> zBase         0.60      0.01     0.57     0.63 1.00     1107     1229
#>
#> Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).
```

The parameter estimates and diagnostics are very similar (as expected).

# Example 3: User Defined Priors
`chkpt_brms()` includes `...`, which passes any number of (valid) arguments to `brm()`. Accordingly, priors can be specified as though `brm()` were used.
```r
path <- create_folder(folder_name = "chkpt_folder_m3")

# priors
bprior <- prior(constant(1), class = "b") +
  prior(constant(2), class = "b", coef = "zBase") +
  prior(constant(0.5), class = "sd")

# fit model
fit_m3 <- chkpt_brms(
  bf(formula = count ~ zAge + zBase + (1 | patient),
     family = poisson()),
  prior = bprior,
  data = epilepsy,
  path = path,
  iter_warmup = 1000,
  iter_sampling = 1000,
  iter_per_chkpt = 250,
  brmsfit = TRUE
)
```

`prior_summary()` can be used to confirm that the priors found their way into the model correctly, i.e.,

```r
prior_summary(fit_m3)

#>                   prior     class      coef   group resp dpar nlpar bound       source
#>             constant(1)         b                                                 user
#>             constant(1)         b      zAge                               (vectorized)
#>             constant(2)         b     zBase                                       user
#>  student_t(3, 1.4, 2.5) Intercept                                              default
#>           constant(0.5)        sd                                                 user
#>           constant(0.5)        sd           patient                       (vectorized)
#>           constant(0.5)        sd Intercept patient                       (vectorized)
```