Model summaries for a Bayesian linear regression

The package mmcc creates tidy summaries of Bayesian models, in the fashion of broom, with one important difference: mmcc returns a data.table instead of a tibble, because the output of a Bayesian model can all too easily become very large.

The aim of this vignette is to demonstrate how to use the two key functions of mmcc: mcmc_to_dt() and tidy() (which, for an mcmc.list, calls mmcc:::tidy.mcmc.list under the hood).

First, we simulate some data to fit a Bayesian model to.

set.seed(4000)
N <- 20

# Covariate: N draws from Uniform(0, 1), sorted into ascending order
x <- sort(runif(n = N))

# Response: the straight line y = 1 + 2x plus Gaussian noise with sd 0.25
y <- rnorm(n = N, mean = 1 + 2 * x, sd = 0.25)

dat <- data.frame(x = x, y = y)

library(ggplot2)

# Scatter plot of the simulated observations
ggplot(dat, aes(x = x, y = y)) +
    geom_point() +
    theme_bw()

Then, we create a grid of M evenly spaced x values, spanning the range of the observed x, at which to make predictions.

# Prediction grid: M evenly spaced points from the smallest to the largest x
M <- 10
x_pred <- seq(min(x), max(x), length.out = M)

Next, we specify the model in the JAGS language as

# JAGS model for a simple linear regression with normal errors, plus a
# block that generates posterior predictions at the m points in x_pred.
# JAGS's dnorm() is parameterised by the precision tau = 1 / sigma^2, so the
# log standard deviation log_sigma maps to tau_y = exp(-2 * log_sigma).
jags_model <- 
"model{
    # model block: observations are normal around the regression line
    for (i in 1:n){
        y[i] ~ dnorm(mu[i], tau_y)
        mu[i] <- beta_0 + beta_1*x[i]
    }

    # prediction block: mean and new observation at each x_pred[i]
    for (i in 1:m){
        y_pred[i] ~ dnorm(mu_pred[i], tau_y)
        mu_pred[i] <- beta_0 + beta_1*x_pred[i]
    }

    # priors (very wide uniforms)
    beta_0 ~ dunif(-1e12, 1e12)
    beta_1 ~ dunif(-1e12, 1e12)
    # precision = 1/sigma^2 = exp(-2*log(sigma))
    tau_y <- exp(-2*log_sigma)
    log_sigma ~ dunif(-1e12, 1e12)
}"

and then compile it with the rjags package, supplying the data and running three chains.

library(rjags)
#> Loading required package: coda
#> Linked to JAGS 4.3.2
#> Loaded modules: basemod,bugs

# Compile the model: the model text is read through a text connection, the
# data list supplies the sizes (n, m), the observations (x, y) and the
# prediction grid (x_pred), and three chains are initialised.
model <- jags.model(
    file = textConnection(jags_model),
    data = list(n = N, x = x, y = y, m = M, x_pred = x_pred),
    n.chains = 3
)
#> Compiling model graph
#>    Resolving undeclared variables
#>    Allocating nodes
#> Graph information:
#>    Observed stochastic nodes: 20
#>    Unobserved stochastic nodes: 13
#>    Total graph size: 126
#> 
#> Initializing model

We first run a burn-in period, whose draws are discarded. We then draw posterior samples of the regression coefficients, the residual precision, and the predictions.

# Burn-in: advance every chain 5000 iterations without storing any draws.
# rjags' update() method for "jags" objects does exactly this; collecting the
# burn-in with jags.samples() would keep thousands of draws per chain only to
# throw them away.
update(model, n.iter = 5000)

# Posterior sample: 10000 further iterations per chain, returned as a coda
# mcmc.list, monitoring the regression coefficients, the residual precision,
# the fitted means at x_pred and the predicted observations.
samples <- coda.samples(model = model,
                        variable.names = c("beta_0", 
                                           "beta_1", 
                                           "tau_y", 
                                           "mu_pred", 
                                           "y_pred"),
                        n.iter = 10000)

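We can now convert the posterior samples to a data.table and summarise the regression parameters. A data.table is well suited to this, because long-format output grows quickly with many samples of many parameters: here it already has 690,000 rows (3 chains × 10,000 iterations × 23 parameters).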

library(mmcc)

# Long-format data.table with one row per (iteration, chain, parameter)
# and the sampled value in `value`
samples_dt <- mcmc_to_dt(samples)
samples_dt
#>         iteration chain  parameter    value
#>             <int> <int>     <fctr>    <num>
#>      1:         1     1     beta_0 1.014051
#>      2:         2     1     beta_0 1.050616
#>      3:         3     1     beta_0 1.011116
#>      4:         4     1     beta_0 1.107581
#>      5:         5     1     beta_0 1.188541
#>     ---                                    
#> 689996:      9996     3 y_pred[10] 2.715669
#> 689997:      9997     3 y_pred[10] 3.107364
#> 689998:      9998     3 y_pred[10] 3.197208
#> 689999:      9999     3 y_pred[10] 3.302505
#> 690000:     10000     3 y_pred[10] 2.629052

# Posterior mean, sd, median and 95% credible interval for the regression
# coefficients and the residual precision
pars_dt <- tidy(samples, conf_level = 0.95,
                colnames = c("beta_0", "beta_1", "tau_y"))
pars_dt
#>    parameter       mean        sd      2.5%     median     97.5%
#>       <fctr>      <num>     <num>     <num>      <num>     <num>
#> 1:    beta_0  0.9271108 0.1000431 0.7290988  0.9275133  1.123403
#> 2:    beta_1  2.2195920 0.1838708 1.8596954  2.2193082  2.582843
#> 3:     tau_y 13.8293627 4.6515008 6.2374860 13.2937580 24.299307

Next, we summarise the line of best fit, mu_pred, and the predictions, y_pred, at the prediction x values.

# Posterior summaries of the fitted mean (mu_pred) and of new observations
# (y_pred) at each point of the prediction grid
mu_dt <- tidy(samples, conf_level = 0.95, colnames = "mu_pred")
y_dt  <- tidy(samples, conf_level = 0.95, colnames = "y_pred")

For plotting, we add the prediction x values to these data tables.

# Attach each prediction's x value. `:=` adds the column by reference, so no
# reassignment is needed.
# NOTE(review): row i is paired with x_pred[i], which assumes tidy() returns
# the rows in index order; the output below confirms this for y_dt.
mu_dt[, x := x_pred]
y_dt[, x := x_pred]
y_dt
#>      parameter      mean        sd      2.5%   median    97.5%         x
#>         <fctr>     <num>     <num>     <num>    <num>    <num>     <num>
#>  1:  y_pred[1] 0.9512734 0.3045908 0.3464125 0.950201 1.557490 0.0112199
#>  2:  y_pred[2] 1.1948567 0.2963282 0.6098184 1.194413 1.789029 0.1202395
#>  3:  y_pred[3] 1.4377712 0.2945579 0.8552373 1.436277 2.025528 0.2292590
#>  4:  y_pred[4] 1.6783725 0.2936560 1.0942014 1.679420 2.256380 0.3382786
#>  5:  y_pred[5] 1.9195364 0.2935423 1.3378726 1.918996 2.501872 0.4472982
#>  6:  y_pred[6] 2.1610523 0.2949844 1.5729255 2.160868 2.744390 0.5563177
#>  7:  y_pred[7] 2.4030173 0.2959932 1.8114716 2.403756 2.983131 0.6653373
#>  8:  y_pred[8] 2.6468318 0.2985019 2.0550072 2.647474 3.234796 0.7743569
#>  9:  y_pred[9] 2.8881227 0.3053736 2.2877094 2.888161 3.500584 0.8833765
#> 10: y_pred[10] 3.1317311 0.3109157 2.5166674 3.131169 3.746946 0.9923960

Now we’ll generate a plot that shows the data, a 95% credible interval for the predictions, $\hat{\boldsymbol{y}}_{\text{pred}}$, and a 95% credible interval for their means, $\hat{\boldsymbol{\mu}}_{\text{pred}}$.

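If we convert only the regression parameters to a data.table, we can look at their posterior distributions and mark each posterior mean and 95% credible interval from pars_dt.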

# Long-format draws for the regression parameters only
tidy_samples <- mcmc_to_dt(samples, 
                           colnames = c("beta_0", 
                                        "beta_1", 
                                        "tau_y"))

# Posterior density of each parameter in its own panel. The 95% credible
# interval from pars_dt is drawn as a thick bar along y = 0, and the
# posterior mean as a white point on that bar.
ggplot(data = tidy_samples, 
       aes(x = value)) +
    geom_density(color = "black", 
                 fill = "grey90") +
    facet_wrap(~parameter,
               nrow = 1,
               scales = "free") +
    theme_bw() +
    # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0)
    geom_segment(data = pars_dt,
                 aes(x = `2.5%`,
                     xend = `97.5%`),
                 y = 0, 
                 yend = 0,
                 linewidth = 2) +
    geom_point(data = pars_dt,
               aes(x = mean),
               y = 0,
               color = "white")

We can also thin the samples (here with thin = 10) to draw trace plots, with one line per chain.

# Thin the long-format draws (thin = 10) to keep the trace plots light
tidy_samples_10 <- thin_dt(tidy_samples, thin = 10)

# Trace plots: one panel per parameter, one coloured line per chain
ggplot(tidy_samples_10, aes(x = iteration, y = value)) +
    geom_line(aes(group = chain, color = factor(chain))) +
    facet_wrap(~ parameter, ncol = 1, scales = "free_y") +
    scale_color_discrete(name = "Chain") +
    theme_bw() +
    theme(legend.position = "bottom")