---
title: "checkpointing: Stan"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{checkpointing: Stan}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

The following examples walk through using **chkptstanr** with [**Stan**](https://mc-stan.org/users/interfaces/rstan).

The basic idea is to (1) write a custom **Stan** model (done by the user), (2) fit the model with [**cmdstanr**](https://mc-stan.org/cmdstanr/) (with the desired number of checkpoints), and then (3) return a `cmdstanr` object. Everything except step (1) is handled internally, so the workflow is very similar to using **cmdstanr**.

## Packages

```r
library(chkptstanr)
library(posterior)
library(bayesplot)
```

# Example 1: Eight Schools

## Storage

The first step is to create a folder that will store the checkpoints:

```r
path <- create_folder(folder_name = "chkpt_folder_m1")
```

## Stan Model

Next is the Stan model:

```r
stan_code <- "
data {
  int n;
  real y[n];
  real sigma[n];
}
parameters {
  real mu;
  real tau;
  vector[n] eta;
}
transformed parameters {
  vector[n] theta;
  theta = mu + tau * eta;
}
model {
  target += normal_lpdf(eta | 0, 1);
  target += normal_lpdf(y | theta, sigma);
}
"
```

## Stan Data

`chkpt_stan()` requires a list supplied to the `data` argument, much like **rstan**.

```r
stan_data <- schools.data <- list(
  n = 8,
  y = c(28, 8, -3, 7, -1, 1, 18, 12),
  sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
)
```

## Model Fitting

### 2 Checkpoints

To show the basic idea of checkpointing, the following run was stopped after 2 checkpoints.

```r
fit_m1 <- chkpt_stan(model_code = stan_code,
                     data = stan_data,
                     iter_warmup = 1000,
                     iter_sampling = 1000,
                     iter_per_chkpt = 250,
                     path = path)
#> Compiling Stan program...
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 8; Iteration: 250 / 2000 (warmup)
#> Chkpt: 2 / 8; Iteration: 500 / 2000 (warmup)
```

### Finish Sampling

To finish the remaining 6 checkpoints, run the same code again:

```r
fit_m1 <- chkpt_stan(model_code = stan_code,
                     data = stan_data,
                     iter_warmup = 1000,
                     iter_sampling = 1000,
                     iter_per_chkpt = 250,
                     path = path)
#> Sampling next checkpoint
#> Chkpt: 3 / 8; Iteration: 750 / 2000 (warmup)
#> Chkpt: 4 / 8; Iteration: 1000 / 2000 (warmup)
#> Chkpt: 5 / 8; Iteration: 1250 / 2000 (sample)
#> Chkpt: 6 / 8; Iteration: 1500 / 2000 (sample)
#> Chkpt: 7 / 8; Iteration: 1750 / 2000 (sample)
#> Chkpt: 8 / 8; Iteration: 2000 / 2000 (sample)
#> Checkpointing complete
```

### Combine Draws

Each post-warmup checkpoint contains 250 posterior draws per chain. These need to be combined with `combine_chkpt_draws()`:

```r
draws <- combine_chkpt_draws(fit_m1)
```

We developed **chkptstanr** to work seamlessly with the **Stan** ecosystem. The object `draws` has been constructed to mimic what is provided when using **cmdstanr** directly.

```r
combine_chkpt_draws(fit_m1)
#> # A draws_array: 1000 iterations, 2 chains, and 19 variables
#> , , variable = lp__
#>
#>          chain
#> iteration   1   2
#>         1 -34 -43
#>         2 -37 -41
#>         3 -36 -39
#>         4 -38 -38
#>         5 -38 -41
#>
#> , , variable = mu
#>
#>          chain
#> iteration    1    2
#>         1  5.2  2.6
#>         2 11.3  6.7
#>         3 -2.7  5.3
#>         4 -2.9  3.7
#>         5 -2.7 14.2
#>
#> , , variable = tau
#>
#>          chain
#> iteration    1     2
#>         1 23.3  2.61
#>         2  6.7  0.21
#>         3 12.7  4.44
#>         4 21.1  7.29
#>         5 18.8 10.94
#>
#> , , variable = eta[1]
#>
#>          chain
#> iteration     1     2
#>         1  0.10 -0.61
#>         2  0.89 -0.87
#>         3  1.62  0.83
#>         4  1.99  0.84
#>         5 -0.16  1.22
#>
#> # ... with 995 more iterations, and 15 more variables
```

### Summary

`draws` can then be used with the `R` package [**posterior**](https://mc-stan.org/posterior/):

```r
posterior::summarise_draws(draws)
#> # A tibble: 19 x 10
#>    variable     mean   median    sd   mad     q5   q95 rhat ess_bulk ess_tail
#>  1 lp__        -39.5    -39.2  2.59  2.58  -44.2 -35.9 1.00     640.    1008.
#>  2 mu           7.77     7.92  5.48  5.10  -1.43  16.0 1.01     530.     325.
#>  3 tau          6.82     5.32  5.75  4.71  0.434  18.7 1.00     649.     658.
#>  4 eta[1]      0.383    0.413 0.929 0.909  -1.20  1.87 1.00    1650.    1233.
#>  5 eta[2]   -0.00335 -0.00816 0.841 0.814  -1.34  1.40 1.00    1443.    1307.
#>  6 eta[3]     -0.176   -0.174 0.931 0.906  -1.67  1.42 1.00    1829.    1424.
#>  7 eta[4]   -0.00521 0.000856 0.862 0.841  -1.47  1.39 1.00    1565.    1407.
#>  8 eta[5]     -0.312   -0.350 0.873 0.835  -1.72  1.24 1.00    1661.    1616.
#>  9 eta[6]     -0.193   -0.190 0.889 0.909  -1.59  1.28 1.00    1915.    1404.
#> 10 eta[7]      0.387    0.358 0.876 0.864  -1.09  1.81 1.00    1574.    1370.
#> 11 eta[8]     0.0805   0.0611 0.970 0.960  -1.51  1.66 1.00    1031.    1236.
#> 12 theta[1]     11.5     10.2  8.29  6.99  0.268  26.4 1.00    1042.     728.
#> 13 theta[2]     7.87     7.87  6.20  5.66  -2.27  17.8 1.00    1549.    1515.
#> 14 theta[3]     6.01     6.63  8.25  6.63  -8.69  18.1 1.00    1102.    1075.
#> 15 theta[4]     7.75     7.76  6.65  5.96  -3.06  18.9 1.00    1674.    1210.
#> 16 theta[5]     5.05     5.70  6.44  5.75  -7.06  14.4 1.00    1405.    1416.
#> 17 theta[6]     6.21     6.60  6.92  6.15  -5.98  16.9 1.00    1890.    1195.
#> 18 theta[7]     10.8     10.1  6.71  6.03  0.992  23.1 1.00    1497.    1767.
#> 19 theta[8]     8.35     8.41  7.72  6.66  -3.88  20.7 1.00    1081.    1075.
```

### Visualization with bayesplot

The popular `R` package [**bayesplot**](https://mc-stan.org/bayesplot/) can also be used.

```r
# geom_vline() comes from ggplot2, which is not attached above, so it is namespaced here
bayesplot::mcmc_trace(draws) +
  ggplot2::geom_vline(xintercept = seq(0, 1000, 250),
                      alpha = 0.25,
                      size = 2)
```

![](../man/figures/stan_f1.png)

The vertical lines are placed at each checkpoint.
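
The checkpoint counts in the progress messages from the Model Fitting section follow from simple arithmetic: 2000 total iterations split into chunks of 250 gives 8 checkpoints, the first 4 of which cover warmup. The sketch below reproduces that schedule in base R. Note that `chkpt_schedule()` is a hypothetical helper written only for this illustration, not a **chkptstanr** function, and it assumes the total number of iterations divides evenly by `iter_per_chkpt`.

```r
# Hypothetical helper (NOT part of chkptstanr): rebuild the checkpoint
# schedule implied by the iteration settings used in this example.
chkpt_schedule <- function(iter_warmup, iter_sampling, iter_per_chkpt) {
  total_iter <- iter_warmup + iter_sampling
  # Assumption: this sketch only handles totals that divide evenly.
  stopifnot(total_iter %% iter_per_chkpt == 0)
  n_chkpt <- total_iter / iter_per_chkpt
  # Last iteration covered by each checkpoint
  last_iter <- seq_len(n_chkpt) * iter_per_chkpt
  data.frame(
    chkpt = seq_len(n_chkpt),
    last_iter = last_iter,
    # A checkpoint is labelled warmup while it ends within the warmup budget
    phase = ifelse(last_iter <= iter_warmup, "warmup", "sample")
  )
}

schedule <- chkpt_schedule(iter_warmup = 1000,
                           iter_sampling = 1000,
                           iter_per_chkpt = 250)
print(schedule)
```

The sampling-phase values of `last_iter` (1250, 1500, 1750, 2000) correspond to post-warmup iterations 250, 500, 750, and 1000, which, together with 0, are the positions passed to `geom_vline()` in the trace plot.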