Bayesian Methods

BSMM8740-2-R-2024F [WEEK - 10]

L.L. Odette

Recap of last week

  • Last week we introduced Markov Chain methods for integration and sampling from probability distributions.
  • We also built a basic understanding of the tools for sampling in Bayesian analysis.

This week

  • We will explore Bayesian methods in greater detail, including Bayesian workflow and model comparison.
  • We will use BRMS (Bayesian Regression Models using Stan), one of the popular R packages for Bayesian analysis.

BRMS

  • Reference materials for BRMS can be found here and here.
  • Instructions for installing BRMS can be found here and here. The basic steps (in order) are
    1. Configure the C++ toolchain (use RTools in Windows)
    2. Install Stan and verify the Stan installation
    3. Install BRMS

Bayesian Estimation

Bayesian Estimation

A statistical model \(M\) is a model of a random process that could have generated our observable data. The observable data \(\mathcal{D}\) contains both dependent variables \(\mathcal{D}_\mathrm{DV}\) and independent variables \(\mathcal{D}_\mathrm{IV}\).

A model \(M\) for data \(\mathcal{D}\) fixes a likelihood function for \(\mathcal{D}_\mathrm{DV}\). The likelihood function often has parameters, represented by a parameter vector \(\theta\).

Using Bayes rule we can estimate the model parameters \(\theta\) from the data

\[ \underbrace{P_M(\theta \vert \mathcal{D})}_{\text{Posterior}} = \frac{1}{\underbrace{P_M(\mathcal{D}_\mathrm{DV}\vert\mathcal{D}_\mathrm{IV})}_{\text{Normalization}}} \overbrace{P_M(\mathcal{D}_\mathrm{DV} \vert\mathcal{D}_\mathrm{IV}, \theta)}^{\text{Likelihood}}\overbrace{P_M(\theta)}^{\text{Prior}} \]

Maximum likelihood (MLE)

Recall that for a linear regression \(y\sim \mathcal{N}(\beta x,\sigma^2)\) the likelihood of any one observation \(y_i\) is (with \(\theta\) representing the set of parameters)

\[ \pi\left(\left.y_{i}\right|x_{i},\beta,\sigma^{2}\right)=\pi\left(\left.y_{i}\right|x_{i},\theta\right)=\frac{1}{\sqrt{2\pi\sigma^{2}}}e^{-\frac{(y_{i}-\beta x_{i})^{2}}{2\sigma^{2}}} \] and the log-likelihood of \(N\) observations \(\{y_i\}_{i=1}^N\) is

\[ \log\prod_{i=1}^{N}\pi\left(\left.y_{i}\right|x_{i},\theta\right) = \sum_{i=1}^{N}\log \pi\left(\left.y_{i}\right|x_{i},\theta\right) \]

Maximum likelihood (MLE)

The maximum likelihood estimate of \(\beta\) is

\[ \hat{\theta}_{\text{MLE}}=\arg\max_{\theta} \sum_{i=1}^{N}\log \pi\left(\left.y_{i}\right|x_{i},\theta\right)=\arg\min_{\theta} -\sum_{i=1}^{N}\log \pi\left(\left.y_{i}\right|x_{i},\theta\right) \]


For the Gaussian likelihood above, this is equivalent to minimizing the sum of squared errors, i.e., the MLE coincides with the ordinary least-squares estimate.
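As a quick numerical check (a minimal sketch, not from the original slides, using simulated data): minimizing the Gaussian negative log-likelihood with optim() recovers the same \(\beta\) as the ordinary least-squares fit from lm().

Code
# minimal sketch (simulated data): Gaussian MLE for beta coincides with least squares
set.seed(8740)
n <- 100
x <- rnorm(n)
y <- 2 * x + rnorm(n)                     # true beta = 2, sigma = 1

neg_log_lik <- function(theta, x, y) {
  beta  <- theta[1]
  sigma <- exp(theta[2])                  # optimize log(sigma) to keep sigma > 0
  -sum(dnorm(y, mean = beta * x, sd = sigma, log = TRUE))
}

mle <- optim(c(0, 0), neg_log_lik, x = x, y = y)
mle$par[1]                                # MLE of beta
coef(lm(y ~ 0 + x))                       # least-squares estimate: numerically the same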

Bayesian model

The Bayesian model for linear regression is (to within a scaling constant)

\[ \begin{align*} \pi_{\theta\vert\mathcal{D}}\left(\left.\theta\right|y_i,x_i\right)\propto\pi_{\mathcal{D}}\left(\left.y_i\right|x_i,\theta\right)\times\pi_\theta\left(\theta\right) \end{align*} \]

where the parameters are \(\theta=\{\beta,\sigma^2\}\).

In words: the posterior probability of the parameters given the observed data is proportional to the probability of the observed data given the parameters, times the prior probabilities of the parameters. In practice we refer to these probabilities as likelihoods, and work with log-likelihoods to avoid the numerical problems arising from products of small probabilities.

Max a posteriori estimate (MAPE)

The maximum a posteriori estimate of the parameters is

\[ \begin{align*} \hat{\theta}_{\text{MAP}} & =\arg\max_{\theta}\log\pi_{\theta\vert\mathcal{D}}\left(\left.\theta\right|\{y_{i},x_{i}\}_{i=1}^{N}\right)\\ & =\arg\max_{\theta}\left(\sum_{i=1}^{N}\log \pi\left(\left.y_{i}\right|x_{i},\theta\right)+\log\pi_\theta\left(\theta\right)\right)\\ & =\arg\min_{\theta}\left(-\sum_{i=1}^{N}\log \pi\left(\left.y_{i}\right|x_{i},\theta\right)-\log\pi_\theta\left(\theta\right)\right) \end{align*} \]

Max a posteriori estimate (MAPE)

If \(\theta\) is not uncertain/random, then \(\pi(\theta)=1\rightarrow\log\pi(\theta)=0\) and the MAPE is equal to the MLE.

In linear regression we assume \(\pi(\sigma^2) \propto 1\), so if \(\theta\) is uncertain/random it remains to give a prior distribution to \(\beta\) (a vector of dimension \(D\), in general). Assume \(\beta\sim\mathcal{N}\left(0,\lambda^{-1}I\right)\) (with a single scale constant \(\lambda\)); then

\[ \pi_\theta(\beta) = \frac{1}{\sqrt{(2\pi)^D \lambda^{-D}}}\exp\left(-\frac{1}{2}(\beta - 0)^\top (\lambda^{-1} I)^{-1} (\beta - 0)\right) = \frac{\lambda^{\frac{D}{2}}}{(2\pi)^{\frac{D}{2}}}\exp\left(-\frac{\lambda}{2} \beta^\top \beta\right) \]

Max a posteriori estimate (MAPE)

With Gaussian \(\pi(\beta)\) as in the last slide, and likelihood

\[ \pi(y_i \vert x_i, \beta) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{1}{2\sigma^2}(y_i- x_i^\top\beta)^2\right) \]

we have, for linear regression

\[ \begin{align*} \hat{\theta}_{\text{MAP}} & =\arg\min_{\theta}\left(-\sum_{i=1}^{N}\log\pi\left(\left.y_{i}\right|x_{i},\theta\right)-\log\pi_\theta\left(\theta\right)\right)\\ & =\arg\min_{\theta}\left(\frac{1}{2\sigma^{2}}\sum_{i=1}^{N}(y_i-x_i^{\top}\beta)^{2}+\frac{\lambda}{2}\beta^\top\beta\right) \end{align*} \]

which is the ridge-regression objective: the MAP estimate shrinks the least-squares solution toward the prior mean (zero), with the relative weights determined by the likelihood variance \(\sigma^2\) and the prior variance \(\lambda^{-1}\).
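In closed form this MAP estimate is \(\hat\beta_{\text{MAP}}=(X^\top X+\lambda\sigma^2 I)^{-1}X^\top y\). A minimal sketch (simulated data; \(\sigma^2\) and \(\lambda\) are assumed fixed for illustration) comparing the closed form with a direct minimization of the penalized objective above:

Code
# minimal sketch (assumed sigma^2 and lambda): MAP with a N(0, 1/lambda) prior on beta
set.seed(8740)
n <- 50; D <- 3
X <- matrix(rnorm(n * D), n, D)
beta_true <- c(1, -2, 0.5); sigma2 <- 1; lambda <- 2
y <- as.vector(X %*% beta_true + rnorm(n, sd = sqrt(sigma2)))

# closed-form MAP / ridge estimate
beta_map <- solve(crossprod(X) + lambda * sigma2 * diag(D), crossprod(X, y))

# the same estimate found by minimizing the penalized sum of squares from the slide
obj <- function(b) sum((y - X %*% b)^2) / (2 * sigma2) + (lambda / 2) * sum(b^2)
optim(rep(0, D), obj, method = "BFGS")$par
drop(beta_map)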

Max a posteriori estimate (MAPE)

In this MAPE for linear regression, with Gaussian priors, the posterior is also a Gaussian, since the product of Gaussian distributions is proportional to a Gaussian distribution, and the denominator in Bayes rule reflects the proportionality.

However, this is a special case in which the likelihood and prior distributions are conjugate.

Conjugate priors

If the posterior distribution \(\pi_{\theta\vert\mathcal{D}}\) is in the same probability distribution family as the prior probability distribution \(\pi_\theta\) (generally this means they are the same to within a normalizing constant), the prior and posterior are then called conjugate distributions, and the prior is called a conjugate prior for the likelihood function \(\pi_{\mathcal{D}}\).

A conjugate prior is an algebraic convenience, giving a closed-form expression for the posterior.

Conjugate priors

example 1

Consider a random variable which consists of the number of successes \(s\) in \(n\) Bernoulli trials with unknown probability of success \(p\in[0,1]\). This random variable will follow the binomial distribution, with a probability mass function of the form

\[ \pi(s\vert p)={n \choose s}p^s(1-p)^{n-s} \]

Conjugate priors

example 1

The usual conjugate prior for the Bernoulli is the beta distribution with parameters (\(\alpha, \beta\)):

\[ \pi_\theta(p;\alpha,\beta) = \frac{p^{\alpha-1}(1-p)^{\beta-1}}{\mathrm{B}(\alpha,\beta)} \]

where \(\alpha\) and \(\beta\) are chosen to reflect any existing belief or information (\(\alpha = 1\) and \(\beta = 1\) would give a uniform distribution) and \(\mathrm{B}(\alpha,\beta)\) is the Beta function acting as a normalising constant.

Conjugate priors

example 1

If we sample this random variable and get \(s'\) successes and \(f=n-s'\) failures, then we have

\[ \begin{align*} \pi_{\theta\vert\mathcal{D}}(p=x) & \propto x^{s'}(1-x)^{n-s'}\times x^{\alpha-1}(1-x)^{\beta-1}\\ & = x^{s'+\alpha-1}(1-x)^{(n-s')+\beta-1}\\ & \propto\pi_{\theta}(x;s'+\alpha,(n-s')+\beta) \end{align*} \]

And the posterior distribution \(\pi_{\theta\vert\mathcal{D}}\) is in the same probability distribution family as the prior probability distribution \(\pi_\theta\).
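A minimal sketch of this conjugate update in R (the prior and success count below are illustrative, not course data): with a Beta(2, 2) prior and, say, 7 successes in 20 trials, the posterior is Beta(2 + 7, 2 + 13), and both densities can be evaluated directly with dbeta(), no sampling required.

Code
# minimal sketch (illustrative numbers): Beta prior + binomial data -> Beta posterior
alpha <- 2; beta <- 2                 # prior Beta(2, 2)
s <- 7; n <- 20                       # assumed data: 7 successes in 20 trials

p <- seq(0.001, 0.999, by = 0.001)
tibble::tibble(p = p,
               prior     = dbeta(p, alpha, beta),
               posterior = dbeta(p, alpha + s, beta + (n - s))) |>
  tidyr::pivot_longer(-p, names_to = "measure", values_to = "density") |>
  ggplot(aes(x = p, y = density, color = measure)) +
  geom_line() +
  labs(title = "Conjugate Beta-binomial update")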

Conjugate priors

example 2

Suppose you’ve been asked to find the probability that you have exactly 5 outages at your website during any hour of the day. Your client has limited data; in fact, they have just three data points, \(y=[3,4,1]\).

If you assume that the data are generated by a Poisson distribution (which has a single parameter, the rate \(\lambda\)), then the maximum likelihood estimate of \(\lambda\) is \(\lambda=\frac{3+4+1}{3}\approx 2.67\), and you would estimate the probability as:

\[ \pi(n=5\vert\lambda\approx 2.67) = \frac{\lambda^n e^{-\lambda}}{n!}=\frac{2.67^5 e^{-2.67}}{5!}=0.078 \]
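This number is just the Poisson probability mass at 5, which we can confirm with a one-line check (a sketch, not from the original slides):

Code
# sanity check of the MLE-based estimate
lambda_mle <- (3 + 4 + 1) / 3
dpois(5, lambda = lambda_mle)   # approximately 0.078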

Conjugate priors

example 2

We’ve assumed that the observed data \(y\) is most likely to have been generated by a Poisson distribution with MLE for \(\lambda= 2.67\).

But the data could also have come from another Poisson distribution, e.g., one with \(\lambda =3\), or \(\lambda =2\), etc. In fact, there is an infinite number of Poisson distributions that could have generated the observed data.

With relatively few data points, we should be quite uncertain about which exact Poisson distribution generated this data. Intuitively, we should instead take a weighted average of \(\pi(n=5\vert\lambda)\) over those Poisson distributions, weighted by how likely each value of \(\lambda\) is given the data we’ve observed.

This is exactly what Bayes’ Rule does.

Conjugate priors

example 2

Luckily, the Poisson distribution has a conjugate, the Gamma distribution:

\[ \pi_\theta(\lambda;\alpha,\beta)=\frac{\lambda^{\alpha-1}e^{-\beta \lambda}\beta^\alpha}{\Gamma(\alpha)} \]

and

\[ \begin{align*} \pi\left(y\vert\lambda\right) & =\prod_{i=1}^{n}\frac{\lambda^{y_{i}}e^{-\lambda}}{y_{i}!}\\ & =\lambda^{n\bar{y}}e^{-n\lambda}\prod_{i=1}^{n}\frac{1}{y_{i}!} \end{align*} \]

so \(\pi\left(y\vert\lambda\right)\times \pi_\theta(\lambda;\alpha,\beta)\propto \lambda^{n\bar{y}+\alpha-1}e^{-(n+\beta)\lambda}\propto\pi_\theta(\lambda;n\bar{y}+\alpha,\,n+\beta)\)

Conjugate priors

example 2

Given our observations (sample mean \(\bar{y}=\frac{3+4+1}{3}\approx 2.67\)), we might take the prior to be a Gamma with \(\alpha=9,\;\beta = 2\) (prior mean \(\alpha/\beta = 4.5\)), so that the prior and posterior look like this:

Code
.shape <- 9 + 3*2.67  ; .rate <- 3+2

tibble::tibble(lambda = seq(0.04,15,0.02), plambda = dgamma(seq(0.04,15,0.02), shape=9, rate = 2), measure = "prior") |> 
  dplyr::bind_rows(
    tibble::tibble(lambda = seq(0.04,15,0.02), plambda = dgamma(seq(0.04,15,0.02), shape=.shape, rate = .rate), measure = "posterior")
  ) |> 
  ggplot(aes(x=lambda, y = plambda, color = measure)) + geom_line() + 
  labs(title = "Probability distributions for Lambda", subtitle = " prior and posterior predictive") +
  theme(legend.position = "right")

Code
.shape <- 9 + 3*2.67  ; .rate <- 3+2
ci <- qgamma(c(0.05,0.95), shape=.shape, rate = .rate) |> 
  purrr::map_dbl(function(x){round( x^5 * exp(-x)/factorial(5),digits=3)})

Given the posterior hyperparameters (\(\alpha=9+3\times2.67\approx 17\), \(\beta=3+2=5\), posterior mean \(\alpha/\beta\approx 3.4\)), we can evaluate the probability of exactly 5 outages at the 5% and 95% posterior quantiles of \(\lambda\), giving a 90% credible range of roughly 0.046 to 0.175 for that probability. This much more conservative statement reflects the uncertainty in the model parameters, which a single point estimate of \(\lambda\) ignores.
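For comparison (a sketch that reuses .shape and .rate from the code above): the exact posterior predictive of a new count under a Gamma posterior for \(\lambda\) is negative binomial, and the same answer can be approximated by averaging dpois(5, lambda) over posterior draws of \(\lambda\).

Code
# sketch: two ways to get the posterior-predictive probability of exactly 5 outages
.shape <- 9 + 3*2.67  ; .rate <- 3+2

# (1) exact: the Gamma-Poisson mixture is a negative binomial
dnbinom(5, size = .shape, prob = .rate / (.rate + 1))

# (2) Monte Carlo: average the Poisson pmf over posterior draws of lambda
lambda_draws <- rgamma(1e5, shape = .shape, rate = .rate)
mean(dpois(5, lambda_draws))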

Conjugate priors

  • Conjugate priors offer ease of computation, efficiency in updates, and clear interpretation, making them suitable for simpler models or real-time applications.

  • However, they are often too restrictive for complex or non-standard models, where flexibility in capturing prior beliefs is crucial.

Conjugate priors

Limitations of Conjugate Priors

  • Restrictive Choice of Priors: Conjugate priors limit the choice of prior distributions to a specific family. This restricts flexibility, especially if real-world data suggests a prior belief outside the conjugate family, which may not accurately capture prior knowledge or uncertainty.
  • Lack of Flexibility with Complex Models: Conjugate priors are often insufficient for complex models, such as hierarchical or multi-level models, where dependencies between variables require more flexible priors. Non-conjugate priors, despite being more computationally intensive, can better accommodate the complexity of these models.
  • Potential for Over-Simplification: Choosing a conjugate prior for convenience can sometimes lead to oversimplification, especially if it does not match the true prior knowledge. This can introduce bias and reduce the model’s accuracy in reflecting genuine prior beliefs.
  • Less Suitable for Non-Standard Likelihoods: Conjugate priors work best with specific likelihood functions, and many real-world problems don’t have standard likelihoods that match conjugate prior forms. In these cases, using a conjugate prior may be infeasible, forcing the use of non-conjugate methods.

Generative (bayesian) modelling

Generative (bayesian) modelling

  • Generative Bayesian modeling is an approach in Bayesian statistics where we create models that describe how data is generated, often by specifying probability distributions for both observed and latent (unobserved) variables/parameters.

  • This process involves defining a generative process — a step-by-step probabilistic framework that models how data could have arisen.

Generative (bayesian) modelling

Note the similarity to DAGs:

Given a DAG, we next assign probability distributions to each node/variable, relate the nodes through the parameters of the distributions and finally assign priors for any remaining/undetermined parameters, including parameters used to define the relationships between variables.
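To make the generative recipe concrete, here is a minimal simulation sketch (the prior, likelihood, and sample size are invented for illustration): draw a parameter from its prior, then draw data from the likelihood given that parameter, mirroring the node-by-node description above.

Code
# minimal generative sketch (assumed prior, likelihood, and sample size)
set.seed(8740)
n_sims <- 1000

sims <- tibble::tibble(
  theta = rbeta(n_sims, 2, 2),                      # prior on a failure probability
  y     = rbinom(n_sims, size = 50, prob = theta)   # likelihood given theta
)

head(sims)   # each row is one (theta, y) draw from the prior predictive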

Generative (bayesian) modelling

Key Components of Generative Bayesian Modeling

  • Defining Priors: Start by assigning prior distributions to parameters, reflecting prior beliefs about these parameters before observing data. Priors incorporate domain knowledge and regularize the model.
  • Likelihood Function: Specify the likelihood, which represents the probability of observing the data given the parameters. It describes how data is assumed to be generated given specific values of the model parameters.
  • Posterior Inference: Using Bayes’ theorem, combine priors and the likelihood to calculate the posterior distribution of the parameters. This posterior reflects updated beliefs after seeing the data.
  • Latent Variables and Hierarchies: Generative models can include latent variables, which represent unobserved or hidden factors, and hierarchical structures, which model data with multiple levels of variation (e.g., nested or grouped data).

Generative (bayesian) modelling

Advantages of Generative Bayesian Models

  • Interpretability: Generative models explicitly describe the data-generating process, making them interpretable and suitable for understanding complex systems.
  • Predictive Power: By learning the underlying structure of the data, generative models can predict unseen outcomes and infer missing or latent data.
  • Uncertainty Quantification: Bayesian models naturally quantify uncertainty in parameter estimates and predictions, enhancing decision-making with probabilistic insights.

Generative (bayesian) modelling

Limitations

  • Computational Complexity: Generative Bayesian models, especially those with complex hierarchical structures or latent variables, can be computationally demanding, often requiring methods like MCMC.
  • Model Specification: The accuracy of generative Bayesian models heavily depends on correctly specifying the generative process, which can be challenging with limited domain knowledge or complex data.

Applications

Generative Bayesian modeling is widely used in areas requiring a deep understanding of data-generating processes, such as healthcare, natural language processing, finance, and other business applications. It enables tasks like anomaly detection, missing data imputation, and causal inference by modeling the probability structure of observed and unobserved variables.

BRMS

Generative modelling with BRMS

BRMS (Bayesian Regression Models using Stan) is an R package for fitting complex Bayesian regression models using Stan, a powerful probabilistic programming language. BRMS provides a high-level, formula-based interface in R, making it easy to specify and fit Bayesian models.

BRMS is useful for performing complex Bayesian analyses in R without diving into raw Stan code.

We’ll start with a very simple example.

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

We express the likelihood for this data (a Bernoulli fail/no-fail outcome per unit) as

\[y_{i} \sim \operatorname{Bernoulli}(\theta)\]

and our prior will be

\[\theta \sim \operatorname{Beta}(\alpha, \beta)\]

Code
dat <- readr::read_csv("data/z15N50.csv", show_col_types = FALSE)

dat |> 
  dplyr::mutate(y = y |> as.character()) |> 
  ggplot(aes(x = y)) +
  geom_bar() +
  scale_y_continuous(expand = expansion(mult = c(0, 0.05))) +
  theme_minimal(base_size = 36)

Generative modelling with BRMS

Basic model: Prior

  • uninformative : \(\theta_c\sim\mathrm{Beta}(1,1)\)
  • weakly informative : \(\theta_c\sim\mathrm{Beta}(5,2)\)
  • strongly informative : \(\theta_c\sim\mathrm{Beta}(50,20)\)
  • point-valued : \(\theta_c\sim\mathrm{Beta}(\alpha,\beta)\) in the limit \(\alpha,\beta\rightarrow\infty\) (a large finite pair such as \(\alpha,\beta=52\) is already highly concentrated); the first three choices are compared in the sketch below
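A small sketch (not part of any model fit) comparing the first three prior choices visually:

Code
# sketch: visualize the candidate Beta priors listed above
p <- seq(0.001, 0.999, by = 0.001)
purrr::map_dfr(
  list(uninformative          = c(1, 1),
       `weakly informative`   = c(5, 2),
       `strongly informative` = c(50, 20)),
  \(ab) tibble::tibble(theta = p, density = dbeta(p, ab[1], ab[2])),
  .id = "prior"
) |>
  ggplot(aes(x = theta, y = density, color = prior)) +
  geom_line() +
  labs(x = expression(theta[c]), y = "prior density")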

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

fit the model
fit8.1 <-
  brms::brm(data = dat, 
      family = brms::bernoulli(link = identity),
      formula = y ~ 1,
      brms::prior(beta(2, 2), class = Intercept, lb = 0, ub = 1),
      iter = 500 + 3334, warmup = 500, chains = 3,
      seed = 8,
      file = "fits/fit08.01")
plot(fit8.1)

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

print(fit8.1)
 Family: bernoulli 
  Links: mu = identity 
Formula: y ~ 1 
   Data: my_data (Number of observations: 50) 
  Draws: 3 chains, each with iter = 3834; warmup = 500; thin = 1;
         total post-warmup draws = 10002

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat
Intercept     0.31      0.06     0.20     0.44 1.00
          Bulk_ESS Tail_ESS
Intercept     3865     4706

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
brms::posterior_summary(fit8.1, robust = T)
            Estimate Est.Error      Q2.5    Q97.5
b_Intercept   0.3127   0.06393   0.19932   0.4415
Intercept     0.3127   0.06393   0.19932   0.4415
lprior        0.2542   0.10789  -0.04339   0.3917
lp__        -32.0781   0.31863 -34.34499 -31.8453

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

extract the draws
draws <- brms::as_draws_df(fit8.1) 
draws
# A draws_df: 3334 iterations, 3 chains, and 4 variables
   b_Intercept Intercept lprior lp__
1         0.30      0.30  0.239  -32
2         0.34      0.34  0.294  -32
3         0.22      0.22  0.021  -33
4         0.29      0.29  0.209  -32
5         0.28      0.28  0.191  -32
6         0.37      0.37  0.334  -32
7         0.30      0.30  0.232  -32
8         0.32      0.32  0.263  -32
9         0.32      0.32  0.264  -32
10        0.27      0.27  0.168  -32
# ... with 9992 more draws
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}
plot the draws by chain
draws |> 
  dplyr::mutate(chain = .chain) |> 
  bayesplot::mcmc_dens_overlay(pars = vars(b_Intercept)) 

plot acf by chain
draws |> 
  dplyr::mutate(chain = .chain) |> 
  bayesplot::mcmc_acf(pars = vars(b_Intercept), lags = 35) +
  theme_minimal()

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

dat |> 
  dplyr::mutate(y = y |> as.character()) |> 
  ggplot(aes(x = y, fill = s)) +
  geom_bar(show.legend = F) +
  ggthemes::scale_fill_colorblind() +
  scale_y_continuous(expand = expansion(mult = c(0, 0.05))) +
  theme_minimal() +
  facet_wrap(~ s)

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

fit the two-site model
fit8.2 <-
  brms::brm(data = dat, 
      family = brms::bernoulli(identity),
      y ~ 0 + s,
      brms::prior(beta(2, 2), class = b, lb = 0, ub = 1),
      iter = 2000, warmup = 500, cores = 4, chains = 4,
      seed = 8,
      file = "fits/fit08.02")
plot chains for both sites
plot(fit8.2, widths = c(2, 3))

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

summary(fit8.2)
 Family: bernoulli 
  Links: mu = identity 
Formula: y ~ 0 + s 
   Data: dat (Number of observations: 15) 
  Draws: 4 chains, each with iter = 2000; warmup = 500; thin = 1;
         total post-warmup draws = 6000

Regression Coefficients:
         Estimate Est.Error l-95% CI u-95% CI Rhat
sLondon      0.37      0.14     0.12     0.66 1.00
sWindsor     0.67      0.13     0.38     0.88 1.00
         Bulk_ESS Tail_ESS
sLondon      5546     3535
sWindsor     5034     3757

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Code
pairs(fit8.2,
      off_diag_args = list(size = 1/3, alpha = 1/3))

Code
draws <- brms::as_draws_df(fit8.2)

draws <-
  draws |> 
  dplyr::rename(theta_Windsor = b_sWindsor, theta_London  = b_sLondon) |> 
  dplyr::mutate(`theta_Windsor - theta_London` = theta_Windsor - theta_London)

long_draws <-
  draws |> 
  dplyr::select(starts_with("theta")) |> 
  tidyr::pivot_longer(everything()) |> 
  dplyr::mutate(name = factor(name, levels = c("theta_Windsor", "theta_London", "theta_Windsor - theta_London"))) 
Warning: Dropping 'draws_df' class as required
metadata was removed.
Code
long_draws |> 
  ggplot(aes(x = value, y = 0, fill = name)) +
  tidybayes::stat_histinterval(point_interval = tidybayes::mode_hdi, .width = .95,
                    slab_color = "white", outline_bars = T,
                    normalize = "panels") +
  scale_fill_manual(values = ggthemes::colorblind_pal()(8)[2:4], breaks = NULL) +
  scale_y_continuous(NULL, breaks = NULL) +
  theme_minimal() +
  facet_wrap(~ name, scales = "free")

Code
long_draws |> 
  dplyr::group_by(name) |> 
  tidybayes::mode_hdi()
# A tibble: 3 × 7
  name     value  .lower .upper .width .point .interval
  <fct>    <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
1 theta_W… 0.698  0.411   0.902   0.95 mode   hdi      
2 theta_L… 0.314  0.105   0.637   0.95 mode   hdi      
3 theta_W… 0.314 -0.0824  0.649   0.95 mode   hdi      

Generative modelling with BRMS

Basic model: manufacturing failures per \(N\) units

separate theta estimates
fit8.3 <-
  brms::brm(data = dat, 
      family = brms::bernoulli(identity),
      y ~ 0 + s,
      prior =
        c(brms::prior(beta(2, 2), class = b, coef = sWindsor),
          brms::prior(beta(2, 2), class = b, coef = sLondon),
          # this just sets the lower and upper bounds
          brms::prior(beta(2, 2), class = b, lb = 0, ub = 1)),
      iter = 2000, warmup = 500, cores = 4, chains = 4,
      sample_prior = "only",
      seed = 8,
      file = "fits/fit08.03")
separate theta estimates
draws <- brms::as_draws_df(fit8.3) |> 
  dplyr::select(starts_with("b_"))
Warning: Dropping 'draws_df' class as required
metadata was removed.
separate theta estimates
# dat |> 
#   dplyr::group_by(s) |> 
#   dplyr::summarise(z = sum(y), N = dplyr::n()) |> 
#   dplyr::mutate(`z/N` = z / N)

levels <- c("theta_Windsor", "theta_London", "theta_Windsor - theta_London")
d_line <-
  tibble::tibble(value = c(.75, .286, .75 - .286),
         name  =  factor(c("theta_Windsor", "theta_London", "theta_Windsor - theta_London"), 
                         levels = levels))

draws |> 
  dplyr::rename(theta_Windsor = b_sWindsor,
         theta_London  = b_sLondon) |> 
  dplyr::mutate("theta_Windsor - theta_London" = theta_Windsor - theta_London) |> 
  tidyr::pivot_longer(contains("theta")) |> 
  dplyr::mutate(name = factor(name, levels = levels)) |>
  
  ggplot(aes(x = value, y = 0)) +
  tidybayes::stat_histinterval(point_interval = tidybayes::mode_hdi, .width = .95,
                    fill = ggthemes::colorblind_pal()(8)[5], normalize = "panels") +
  geom_vline(data = d_line, 
             aes(xintercept = value), 
             linetype = 2) +
  scale_y_continuous(NULL, breaks = NULL) +
  labs(subtitle = expression("The dashed vertical lines mark off "*italic(z[s])/italic(N[s]))) +
  cowplot::theme_cowplot() +
  facet_wrap(~ name, scales = "free")

method for separate theta estimates
draws |> 
  dplyr::rename(theta_Windsor = b_sWindsor,
         theta_London  = b_sLondon) |> 
  
  ggplot(aes(x = theta_Windsor, y = theta_London)) +
  geom_point(alpha = 1/4, color = ggthemes::colorblind_pal()(8)[6]) +
  coord_equal() +
  cowplot::theme_minimal_grid()

Generative modelling with BRMS

Recall that our likelihood here is the Bernoulli distribution,

\[y_i \sim \operatorname{Bernoulli}(\theta).\]

We’ll use the beta density for our prior distribution for \(\theta\),

\[\theta \sim \operatorname{Beta}(\alpha, \beta).\]

And we can re-express \(\alpha\) and \(\beta\) in terms of the mode \(\omega\) and concentration \(\kappa\), such that

\[\alpha = \omega(\kappa - 2) + 1 \;\;\; \textrm{and} \;\;\; \beta = (1 - \omega)(\kappa - 2) + 1.\]

As a consequence, we can re-express \(\theta\) as

\[\theta \sim \operatorname{Beta}(\omega(\kappa - 2) + 1, (1 - \omega)(\kappa - 2) + 1).\]

The value of \(\kappa\) governs how near \(\theta\) is to \(\omega\), with larger values of \(\kappa\) generating values of \(\theta\) more concentrated near \(\omega\).
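A small helper implementing this re-parameterization (a sketch in the same spirit as the gamma helper functions that follow):

Code
# sketch: Beta(alpha, beta) from mode omega and concentration kappa
beta_a_b_from_omega_kappa <- function(omega, kappa) {
  if (omega <= 0 || omega >= 1) stop("omega must be in (0, 1)")
  if (kappa <= 2)               stop("kappa must be > 2")
  list(alpha = omega * (kappa - 2) + 1,
       beta  = (1 - omega) * (kappa - 2) + 1)
}

beta_a_b_from_omega_kappa(omega = 0.25, kappa = 12)   # alpha = 3.5, beta = 8.5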

Generative modelling with BRMS

Similarly, for a gamma distribution, the shape \(s\) and rate \(r\) can be written in terms of the mean \(\mu\) or mode \(\omega\) together with the standard deviation \(\sigma\):

\[ \begin{align*} s & =\frac{\mu^{2}}{\sigma^{2}}\;\;\;\text{and}\;\;\;r=\frac{\mu}{\sigma^{2}}\;\;\;\text{for mean}\;\;\;\mu>0\\ s & =1+\omega r\;\;\;\text{where}\;\;\;r=\frac{\omega+\sqrt{\omega^{2}+4\sigma^{2}}}{2\sigma^{2}}\;\;\;\text{for mode}\;\;\;\omega>0 \end{align*} \]


re-parameterization functions
gamma_s_and_r_from_mean_sd <- function(mean, sd) {
  if (mean <= 0) stop("mean must be > 0")
  if (sd   <= 0) stop("sd must be > 0")
  shape <- mean^2 / sd^2
  rate  <- mean   / sd^2
  return(list(shape = shape, rate = rate))
}

gamma_s_and_r_from_mode_sd <- function(mode, sd) {
  if (mode <= 0) stop("mode must be > 0")
  if (sd   <= 0) stop("sd must be > 0")
  rate  <- (mode + sqrt(mode^2 + 4 * sd^2)) / (2 * sd^2)
  shape <- 1 + mode * rate
  return(list(shape = shape, rate = rate))
}

Generative modelling with BRMS

multiple sites

dat <- readr::read_csv("data/TherapeuticTouchData.csv", show_col_types = FALSE)
Code
dat |> 
  dplyr::mutate(y = y |> as.character()) |> 
  
  ggplot(aes(y = y)) +
  geom_bar(aes(fill = after_stat(count))) +
  scale_fill_viridis_c(option = "A", end = .7, breaks = NULL) +
  scale_x_continuous(breaks = 0:4 * 2, expand = c(0, NA), limits = c(0, 9)) +
  cowplot::theme_minimal_vgrid() +
  cowplot::panel_border() +
  facet_wrap(~ s, ncol = 7)

Code
a_purple <- viridis::viridis_pal(option = "A")(9)[4]
dat |> 
  dplyr::group_by(s) |> 
  dplyr::summarize(mean = mean(y)) |>
  
  ggplot(aes(x = mean)) +
  geom_histogram(color = "white", fill = a_purple,
                 linewidth = .2, binwidth = .1) +
  scale_x_continuous("Proportion Not Failing", limits = c(0, 1)) +
  scale_y_continuous("# Practitioners", expand = c(0, NA)) +
  cowplot::theme_minimal_hgrid()

Generative modelling with BRMS

fit multiple sites | hierarchical

Code
fit9.1 <-
  brms::brm(data = dat,
      family = brms::bernoulli(link = logit),
      y ~ 1 + (1 | s),
      prior = c(brms::prior(normal(0, 1.5), class = Intercept),
                brms::prior(normal(0, 1), class = sd)),
      iter = 20000, warmup = 1000, thin = 10, chains = 4, cores = 4,
      seed = 9,
      file = "fits/fit09.01")
print(fit9.1)
 Family: bernoulli 
  Links: mu = logit 
Formula: y ~ 1 + (1 | s) 
   Data: dat (Number of observations: 280) 
  Draws: 4 chains, each with iter = 20000; warmup = 1000; thin = 10;
         total post-warmup draws = 7600

Multilevel Hyperparameters:
~s (Number of levels: 28) 
              Estimate Est.Error l-95% CI u-95% CI
sd(Intercept)     0.28      0.18     0.01     0.68
              Rhat Bulk_ESS Tail_ESS
sd(Intercept) 1.00     7303     7351

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat
Intercept    -0.25      0.13    -0.52     0.01 1.00
          Bulk_ESS Tail_ESS
Intercept     7665     7348

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Code
plot(fit9.1, widths = c(2, 3))

population parameter estimates (log-odds)
Code
draws <- brms::as_draws_df(fit9.1)
draws |> 
  dplyr::mutate(chain = .chain) |> 
  bayesplot::mcmc_acf(pars = vars(b_Intercept, sd_s__Intercept), lags = 10) +
  cowplot::theme_cowplot()

acf by chain; population parameters
Code
bayesplot::neff_ratio(fit9.1) |> 
  bayesplot::mcmc_neff() +
  cowplot::theme_cowplot(font_size = 12)

effective samples by variable
Code
draws_small <-
  draws |> 
  # convert the linear model parameters to the probability space with `inv_logit_scaled()`
  dplyr::mutate(`theta[1]`  = (b_Intercept + `r_s[S01,Intercept]`) |> brms::inv_logit_scaled(),
         `theta[14]` = (b_Intercept + `r_s[S14,Intercept]`) |> brms::inv_logit_scaled(),
         `theta[28]` = (b_Intercept + `r_s[S28,Intercept]`) |> brms::inv_logit_scaled()) |> 
  # make the difference distributions
  dplyr::mutate(`theta[1] - theta[14]`  = `theta[1]`  - `theta[14]`,
         `theta[1] - theta[28]`  = `theta[1]`  - `theta[28]`,
         `theta[14] - theta[28]` = `theta[14]` - `theta[28]`) |> 
  dplyr::select(starts_with("theta"))

draws_small |> 
  tidyr::pivot_longer(everything()) |> 
  # this line is unnecessary, but will help order the plots 
  dplyr::mutate(name = factor(name, levels = c("theta[1]", "theta[14]", "theta[28]", 
                                        "theta[1] - theta[14]", "theta[1] - theta[28]", "theta[14] - theta[28]"))) |> 

  ggplot(aes(x = value, y = 0)) +
  tidybayes::stat_histinterval(point_interval = ggdist::mode_hdi, .width = .95,
                    fill = a_purple, breaks = 40, normalize = "panels") +
  scale_y_continuous(NULL, breaks = NULL) +
  xlab(NULL) +
  cowplot::theme_minimal_hgrid() +
  facet_wrap(~ name, scales = "free", ncol = 3)

selected parameters and contrasts on probability scale
Code
draws_small |> 
  tidyr::pivot_longer(everything()) |>
  dplyr::group_by(name) |> 
  tidybayes::mode_hdi(value) |> 
  gt::gt("name") |> 
  gt::fmt_number(columns=value:.width, decimals=3) |> 
  gt::tab_header(title = "Parameter values and contrasts (probability scale)", subtitle = "sites 1, 14, 28") |> 
  gtExtras::gt_theme_espn()
Parameter values and contrasts (probability scale)
sites 1, 14, 28

                        value  .lower  .upper  .width  .point  .interval
theta[1]                0.422   0.207   0.522   0.950   mode    hdi
theta[1] - theta[14]   −0.001  −0.275   0.120   0.950   mode    hdi
theta[1] - theta[28]   −0.002  −0.432  −0.419   0.950   mode    hdi
theta[1] - theta[28]   −0.002  −0.405   0.069   0.950   mode    hdi
theta[14]               0.435   0.281   0.577   0.950   mode    hdi
theta[14] - theta[28]  −0.001  −0.330   0.106   0.950   mode    hdi
theta[28]               0.452   0.364   0.701   0.950   mode    hdi
Code
bayesplot::color_scheme_set("purple")
bayesplot::bayesplot_theme_set(bayesplot::theme_default() + cowplot::theme_minimal_grid())

stats::coef(fit9.1, summary = F)$s |> 
  brms::inv_logit_scaled() |> 
  data.frame() |> 
  dplyr::rename(`theta[1]`  = S01.Intercept, 
         `theta[14]` = S14.Intercept, 
         `theta[28]` = S28.Intercept) |> 
  dplyr::select(`theta[1]`, `theta[14]`, `theta[28]`) |> 
  bayesplot::mcmc_pairs(off_diag_args = list(size = 1/8, alpha = 1/8)) 

pairs plot on probability scale

Generative modelling with BRMS

Shrinkage in hierarchical models

“In typical hierarchical models, the estimates of low-level parameters are pulled closer together than they would be if there were not a higher-level distribution. This pulling together is called shrinkage of the estimates”

Further,

“shrinkage is a rational implication of hierarchical model structure, and is (usually) desired by the analyst because the shrunken parameter estimates are less affected by random sampling noise than estimates derived without hierarchical structure. Intuitively, shrinkage occurs because the estimate of each low-level parameter is influenced from two sources: (1) the subset of data that are directly dependent on the low-level parameter, and (2) the higher-level parameters on which the low-level parameter depends. The higher-level parameters are affected by all the data, and therefore the estimate of a low-level parameter is affected indirectly by all the data, via their influence on the higher-level parameters.”1

Generative modelling with BRMS

Shrinkage in hierarchical models

multilevel shrinkage
dat |> 
  dplyr::group_by(s) |> 
  dplyr::summarise(p = mean(y)) |> 
  dplyr::mutate(theta = coef(fit9.1)$s[, 1, "Intercept"] |> brms::inv_logit_scaled()) |> 
  tidyr::pivot_longer(-s) |> 
  # add a little jitter to reduce the overplotting
  dplyr::mutate(value = value + runif(n = dplyr::n(), min = -0.02, max = 0.02),
         name  = dplyr::if_else(name == "p", "italic(z/N)", "theta")) |> 

  ggplot(aes(x = value, y = name, group = s)) +
  geom_point(color = alpha(a_purple, 1/2)) +
  geom_line(linewidth = 1/3, alpha = 1/3) +
  scale_x_continuous(breaks = 0:5 / 5, expand = c(0.01, 0.01), limits = 0:1) +
  scale_y_discrete(NULL, labels = ggplot2:::parse_safe) +
  labs(title = "Multilevel shrinkage in model",
       x = "data proportion or theta value") +
  cowplot::theme_minimal_hgrid() +
  cowplot::panel_border() 

Generative modelling with BRMS

speeding up model fits

Full Bayesian inference is a computationally demanding task, and we often need our models to run in less walltime (elapsed real time). Modern machines have multiple processor cores, so running the inference in parallel can shorten the overall walltime.

Between-chain parallelization is straightforward: merely launch multiple chains at the same time. Within-chain parallelization is more complicated in various ways; the brms threading vignette is an introduction to it, and a sketch of both options follows.
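A hedged sketch of both options (the formula reuses the epilepsy data loaded on a later slide; chain, core, and thread counts are illustrative, and within-chain threading requires a threading-capable backend such as cmdstanr):

Code
# sketch: between-chain parallelization -- run the 4 chains on 4 cores
fit_parallel <- brms::brm(
  count ~ Trt * Base, data = epilepsy,
  chains = 4, cores = 4, seed = 8740
)

# sketch: within-chain parallelization -- additionally split each chain's
# likelihood evaluation over 2 threads (requires the cmdstanr backend)
fit_threaded <- brms::brm(
  count ~ Trt * Base, data = epilepsy,
  chains = 4, cores = 4,
  threads = brms::threading(2),
  backend = "cmdstanr",
  seed = 8740
)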

Generative modelling with BRMS

regression models

Code
# load data
data("epilepsy", package = "brms")

epilepsy
    Age Base Trt patient visit count obs     zAge     zBase
1    31   11   0       1     1     5   1  0.42500 -0.757173
2    30   11   0       2     1     3   2  0.26528 -0.757173
3    25    6   0       3     1     2   3 -0.53327 -0.944403
4    36    8   0       4     1     4   4  1.22355 -0.869511
5    22   66   0       5     1     7   5 -1.01241  1.302363
6    29   27   0       6     1     5   6  0.10557 -0.158035
(output truncated; 236 rows of 9 variables in total)
Code
output <- 
  capture.output(
    fit_epi_gaussian1 <- 
      brms::brm(
        count ~ 1 + Trt, data = epilepsy, 
        silent = 2, seed = 8740, file = "fits/fit_epi_gaussian1"
      )
  )
fit_epi_gaussian1
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: count ~ 1 + Trt 
   Data: epilepsy (Number of observations: 236) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat
Intercept     8.42      1.18     6.04    10.70 1.00
Trt1         -0.58      1.64    -3.82     2.68 1.00
          Bulk_ESS Tail_ESS
Intercept     4046     2833
Trt1          4033     2967

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat
sigma    12.35      0.57    11.32    13.53 1.00
      Bulk_ESS Tail_ESS
sigma     3533     2402

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Generative modelling with BRMS

Here we see, for each parameter separately, traces of the four MCMC chains in different colors, with post-warmup iterations on the x-axis and parameter values on the y-axis. These trace plots show ideal convergence: all chains overlay each other nicely, are stationary (horizontal on average), and show little autocorrelation.

Code
brms::mcmc_plot(fit_epi_gaussian1, type = "trace")

Generative modelling with BRMS

Having convinced ourselves of convergence, at least graphically for now, we can move on to inspecting the posteriors, e.g., by plotting histograms of the posterior samples per parameter.

Code
brms::mcmc_plot(fit_epi_gaussian1, type = "hist", bins=30)

Generative modelling with BRMS

In R, we can easily perform such transformations by first extracting the posterior draws and then applying the transformation per draw in a vectorized manner, e.g., using the functionality from the posterior package:

Code
draws <- 
  posterior::as_draws_df(fit_epi_gaussian1) |> 
  posterior::mutate_variables(
    variance = sigma^2, mu_Trt = b_Intercept + b_Trt1
  )

draws
# A draws_df: 1000 iterations, 4 chains, and 8 variables
   b_Intercept b_Trt1 sigma Intercept lprior lp__
1          7.0   2.04    12       8.1   -7.3 -934
2          9.1  -1.69    13       8.2   -7.4 -933
3          9.0  -0.77    12       8.6   -7.3 -933
4          7.5   1.15    12       8.1   -7.4 -933
5          8.5  -0.89    13       8.0   -7.4 -932
6          9.1  -2.18    12       8.0   -7.2 -933
7          8.1   0.81    13       8.5   -7.5 -933
8          8.3   1.19    12       8.9   -7.4 -934
9          7.1   1.16    12       7.7   -7.2 -933
10         9.7  -1.09    13       9.1   -7.8 -934
   variance mu_Trt
1       151    9.0
2       160    7.4
3       134    8.2
4       156    8.7
5       158    7.6
6       144    6.9
7       160    8.9
8       140    9.5
9       143    8.3
10      171    8.6
# ... with 3990 more draws
# ... hidden reserved variables {'.chain', '.iteration', '.draw'}

Generative modelling with BRMS

With mu_Trt we computed the model-implied predictions of the mean for the treatment group. For the control group, it would just be mu_Ctrl = \(\beta_0\) for this simple model.

Code
bayesplot::mcmc_hist(draws, c("variance", "mu_Trt"), bins = 30)

Generative modelling with BRMS

For more complex models, computing the predictions for different groups or, more generally, different predictor values by hand becomes quite cumbersome. For this reason brms provides a convenient method for quick graphical summaries of the model-implied predictions per predictor:

Code
ce <- brms::conditional_effects(fit_epi_gaussian1)
plot(ce)

Generative modelling with BRMS

The error bars in the last slide represent 95% credible intervals by default, but we can change that value via the prob argument. Comparing the conditional_effects plot with our manually computed posterior of mu_Trt, we see that they are in fact the same.

To put the mean predictions into the context of the observed data, we can also show the data as points in the plot:

Code
plot(ce, points = TRUE)

Generative modelling with BRMS

What we do in conditional_effects by default is visualize the expected value (mean parameter) of the likelihood distribution, conditional on certain predictor values. In brms, this is done via the posterior_epred (posterior expected predictions) method. For example, we can run the code below to create expected posterior predictions for both groups. The resulting object contains the posterior draws in the rows and the different conditions (here, treatment groups) in the columns.

Code
newdata <- data.frame(Trt = c(0, 1))
pe <- brms::posterior_epred(fit_epi_gaussian1, newdata = newdata)

We also can summarize the draws, for example, via

brms::posterior_summary(pe)
     Estimate Est.Error  Q2.5 Q97.5
[1,]    8.416     1.179 6.035 10.70
[2,]    7.836     1.125 5.574 10.07

In linear models, posterior_epred directly coincides with evaluating the linear predictor \(\mu\) as exemplified above. What posterior_epred does not include is the residual uncertainty, which is represented by \(\sigma\) in our linear models.

Generative modelling with BRMS

Consider again the task of evaluating predictions for the treatment group. If we are only interested in (the posterior of) the likelihood’s mean parameter, we would compute \(\mu^{(s)}_{\mathrm{Trt}}=\beta_0^{(s)} + \beta_1^{(s)}\).

In contrast, if we are interested in prediction of hypothetical new data points \(y^{(s)}_{\mathrm{Trt}}\) from the treatment group (i.e., actual posterior predictions), we would sample

\[ y^{(s)}_{\mathrm{Trt}}\sim\mathscr{N}\left(\mu^{(s)}_{\mathrm{Trt}}, \sigma^{(s)} \right) \]
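A manual sketch of that sampling step, using the draws object created earlier (mu_Trt and sigma are columns of those draws):

Code
# sketch: one hypothetical new treatment-group observation per posterior draw
y_trt_rep <- with(draws, rnorm(n = length(mu_Trt), mean = mu_Trt, sd = sigma))
quantile(y_trt_rep, probs = c(0.025, 0.5, 0.975))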

Generative modelling with BRMS

This is exactly what happens behind the scenes when we execute the code below:

Code
options(brms.plot_points = TRUE)
brms::conditional_effects(fit_epi_gaussian1, method = "posterior_predict")

Generative modelling with BRMS

We could have also done this more manually via

Code
newdata <- data.frame(Trt = c(0, 1))
pp <- brms::posterior_predict(fit_epi_gaussian1, newdata = newdata)

brms::posterior_summary(pp)
     Estimate Est.Error   Q2.5 Q97.5
[1,]    8.140     12.38 -15.93 32.63
[2,]    7.612     12.37 -16.66 32.18

Generative modelling with BRMS

We are already aware that the linear regression model is not ideal for the epilepsy data. But how bad is it? As a quick graphical method, we can use posterior predictive (PP) checks, where we compare the observed outcome data with the model-predicted outcome data, that is, with the posterior predictions. In brms, we can perform PP checks via:

Code
brms::pp_check(fit_epi_gaussian1)

Generative modelling with BRMS

We see that the model predictions can account neither for the strong spike of observed outcomes close to zero nor for their right-skewness. Instead, the model also predicts a lot of negative outcomes, which is impossible in reality because we are predicting counts of epileptic seizures.

In the plot, it looks as if the observed data also had a few negative values (the dark blue density going below zero), but this is just an artifact of estimating a continuous density from counts. While this PP-check type is definitely not ideal for count outcome data, it still very clearly points to the shortcomings of our linear model.

Generative modelling with BRMS

While the default PP-check was already eye-opening, there are a lot of check types that can further our understanding of model appropriateness. For example, an often very useful check compares the residuals (observed outcomes minus model predictions) with the observed outcomes, also known as a residual plot. In pp_check this check type is called error_scatter_avg:

Code
brms::pp_check(fit_epi_gaussian1, type = "error_scatter_avg")

Generative modelling with BRMS

Here there is an almost perfectly linear relationship, indicating serious problems with the independence assumption of the errors.

Essentially, both PP-checks have told us that our initial model is a very bad fit for the data at hand.

If you don’t know which check types are available, you can simply pass an arbitrary non-supported type name to get a list of all currently supported types:

brms::pp_check(fit_epi_gaussian1, type = "help_me")

Generative modelling with BRMS

Given the spike near zero and the heavy right tail we saw in the PP-checks, one option is to swap the Gaussian likelihood for a heavier-tailed Student-t likelihood:

Code
fit_epi_student1 <- 
  brms::brm(
    count ~ Trt * Base,
    data = epilepsy,
    family = brms::student()
    , silent = 2, seed = 8740, file = "fits/fit_epi_student1"
  )
summary(fit_epi_student1)
 Family: student 
  Links: mu = identity; sigma = identity; nu = identity 
Formula: count ~ Trt * Base 
   Data: epilepsy (Number of observations: 236) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat
Intercept     0.01      0.37    -0.71     0.74 1.00
Trt1          0.42      0.63    -0.82     1.62 1.00
Base          0.26      0.01     0.23     0.28 1.00
Trt1:Base    -0.11      0.02    -0.16    -0.07 1.00
          Bulk_ESS Tail_ESS
Intercept     2992     3117
Trt1          2133     2665
Base          2672     2228
Trt1:Base     1920     2163

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat
sigma     2.10      0.21     1.71     2.54 1.00
nu        1.39      0.18     1.09     1.78 1.00
      Bulk_ESS Tail_ESS
sigma     2557     2421
nu        2246     1553

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Code
brms::pp_check(fit_epi_student1) + xlim(-30, 30)
Warning: Removed 55 rows containing non-finite outside the
scale range (`stat_density()`).
Warning: Removed 7 rows containing non-finite outside the scale
range (`stat_density()`).

Model comparison

Generative modelling with BRMS

Absolute predictive performance

Code
# fit with a covariate
fit_epi_gaussian2 <- brms::brm(
  count ~ Trt + Base, data = epilepsy, silent = 2, seed = 8740, file = "fits/fit_epi_gaussian2"
  )
# fit with a covariate and interaction
fit_epi_gaussian3 <- brms::brm(
  count ~ Trt * Base, data = epilepsy, silent = 2, seed = 8740, file = "fits/fit_epi_gaussian3"
  )
brms::pp_check(fit_epi_gaussian2) + theme_minimal(base_size = 18) + theme(legend.position = "right")
Using 10 posterior draws for ppc type 'dens_overlay' by default.

Default posterior predictive check for model fit_epi_gaussian2
Code
brms::pp_check(fit_epi_gaussian2, type = "error_scatter_avg") + theme_minimal(base_size = 18)
Using all posterior draws for ppc type 'error_scatter_avg' by default.

Posterior predictive check comparing observed responses (y-axis) with the predictive residuals (x-axis) of model fit_epi_gaussian2

Generative modelling with BRMS

Measures of explained variance: \(R^2\)

The standard \(R^2_{\mathrm{basic}}\) measure is defined only for Gaussian models and is often referred to as the percentage of explained variance. Note that in the Bayesian context we obtain posterior draws of \(R^2_{\mathrm{basic}}\), one per posterior draw of the parameters.
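In formula form, this is what the code below computes for each posterior draw \(s\) (with \(\mu_n^{(s)}\) the posterior_epred draw for observation \(n\) and \(\bar{y}\) the sample mean of the outcome):

\[ R^{2,(s)}_{\mathrm{basic}} = 1 - \frac{\sum_{n=1}^{N}\left(y_n-\mu_n^{(s)}\right)^2}{\sum_{n=1}^{N}\left(y_n-\bar{y}\right)^2} \]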

# compute draws of the predictive errors based on posterior_epred
errors <- brms::predictive_error(fit_epi_gaussian2, method = "posterior_epred")
str(errors)
 num [1:4000, 1:236] 1.36 3.07 3.35 3.2 3.09 ...
# sum errors over observations
error_variation <- rowSums(errors^2)
str(error_variation)
 num [1:4000] 15385 15225 15062 15072 15047 ...
# compute R2_basic
overall_variation <- sum((epilepsy$count - mean(epilepsy$count))^2)
R2_basic_epi_gaussian2 <- 1 - error_variation / overall_variation
brms::posterior_summary(R2_basic_epi_gaussian2)
     Estimate Est.Error  Q2.5  Q97.5
[1,]   0.5765  0.004541 0.565 0.5816

Generative modelling with BRMS

Measures of explained variance: \(R^2\)

The \(R^2_{\mathrm{basic}}\) is a good starting point, but it doesn't generalize to models that are more complicated than Gaussian linear models. In particular, it doesn't readily generalize to most other likelihood families.

We’ll use a more general form of \(R^2\) that we can apply to (almost) all brms models, regardless of what likelihood families they have. The measure is based on the ratio of explained variance and the sum of explained and error variance:

\[ R^2_{\mathrm{general}} = \frac{\mathrm{Var}(\hat{y})}{\mathrm{Var}(\hat{y})+\mathrm{Var}(\hat{e})} \]

where \(\mathrm{Var}(\hat{y})\) is the variance of the posterior predicted means over observations (again based on posterior_epred) and \(\mathrm{Var}(\hat{e})\) is the variance of the model-implied errors over observations, with \(\hat{e}_n=y_n-\hat{y}_n\).

brms::bayes_R2(fit_epi_gaussian2)
   Estimate Est.Error Q2.5  Q97.5
R2   0.5795   0.02776 0.52 0.6274
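As a rough manual cross-check of this definition (a sketch; the internal brms computation may differ in details):

Code
mu_draws  <- brms::posterior_epred(fit_epi_gaussian2)                               # S x N matrix of means
err_draws <- brms::predictive_error(fit_epi_gaussian2, method = "posterior_epred")  # S x N matrix of errors
var_mu  <- apply(mu_draws, 1, var)   # per draw: variance of predicted means over observations
var_err <- apply(err_draws, 1, var)  # per draw: variance of errors over observations
brms::posterior_summary(var_mu / (var_mu + var_err))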

Generative modelling with BRMS

Measures of Squared Errors

\(\mathrm{RMSE}_{\mathrm{basic}}\) computes a root mean squared error over observations \(n\) for each posterior draw \(s\), and thus yields a posterior distribution over RMSE values.
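In formula form (with \(\mu_n^{(s)}\) the posterior_epred draw for observation \(n\)):

\[ \mathrm{RMSE}^{(s)}_{\mathrm{basic}} = \sqrt{\frac{1}{N}\sum_{n=1}^{N}\left(y_n-\mu_n^{(s)}\right)^2} \]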

# compute draws of the predictive errors based on posterior_epred
errors_epi_gaussian3 <-
brms::predictive_error(fit_epi_gaussian3, method = "posterior_epred")
str(errors_epi_gaussian3)
 num [1:4000, 1:236] 1.15 1.67 1.53 2.42 2.19 ...
# root mean of squared errors over observations
rmse_basic_epi_gaussian3 <- sqrt(rowMeans(errors_epi_gaussian3^2))
str(rmse_basic_epi_gaussian3)
 num [1:4000] 7.72 7.71 7.71 7.73 7.71 ...
Code
lattice::histogram(rmse_basic_epi_gaussian3) 

Posterior histogram of RMSE_basic for model fit_epi_gaussian3.

Generative modelling with BRMS

Measures of Squared Errors

We can also exchange the roles of \(n\) and \(s\) and compute a root mean squared error over draws \(s\) for each observation \(n\):
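In formula form, each observation \(n\) now gets its own value:

\[ \mathrm{RMSE}_{\mathrm{alt},n} = \sqrt{\frac{1}{S}\sum_{s=1}^{S}\left(y_n-\mu_n^{(s)}\right)^2} \]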

rmse_alt_epi_gaussian3 <- 
  sqrt(colMeans(errors_epi_gaussian3^2))
str(rmse_alt_epi_gaussian3)
 num [1:236] 1.96 0.961 1.014 1.813 10.987 ...
Code
lattice::histogram(rmse_alt_epi_gaussian3) 

Posterior histogram of RMSE_alt for model fit_epi_gaussian3.

In this case, we get a distribution of RMSE over observations, where each individual RMSE value would be computed over the posterior predictive distribution of a single observation. Both of the above RMSE measures are fully Bayesian as they take into account the uncertainty in the posterior distribution, but in different ways.

Generative modelling with BRMS

Measures of Squared Errors

Typically, only a point estimate \(\hat{\bar{y}}_n\) is used to represent the model-implied predictions, instead of a (posterior) distribution over such predictions for each \(n\). For example, for a Bayesian model this point estimate could simply be the posterior mean.

When using such a point prediction approach, our RMSE definition becomes:
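\[ \mathrm{RMSE}_{\mathrm{point}} = \sqrt{\frac{1}{N}\sum_{n=1}^{N}\left(y_n-\hat{\bar{y}}_n\right)^2} \]

with \(\hat{\bar{y}}_n\) the posterior mean prediction for observation \(n\), as computed in the code below.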

# extract a point estimate of the predictions per observation
ppmean_epi_gaussian3 <- colMeans(brms::posterior_epred(fit_epi_gaussian3))
str(ppmean_epi_gaussian3)
 num [1:236] 3.27 3.27 1.94 2.47 17.91 ...
# compute RMSE based on the responses and point predictions
rmse_point_epi_gaussian3 <- sqrt(mean((epilepsy$count - ppmean_epi_gaussian3)^2))

Generative modelling with BRMS

Relative predictive performance

In general, it is more common to compare multiple models against each other and thus investigate their relative predictive performance.

errors_epi_student1 <-
  brms::predictive_error(fit_epi_student1, method = "posterior_epred")
rmse_alt_epi_student1 <- sqrt(colMeans(errors_epi_student1^2))

We can now even compute the pointwise (per-observation) difference in RMSE values:

rmse_alt_diff <- rmse_alt_epi_student1 - rmse_alt_epi_gaussian3
str(rmse_alt_diff)
 num [1:236] 0.219 -0.627 -0.462 0.144 -0.966 ...
se_mean <- function(x) {sd(x) / sqrt(length(x))}

se_rmse_alt_diff <- se_mean(rmse_alt_diff)
se_rmse_alt_diff
[1] 0.367
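To judge whether the Student-t model's pointwise RMSE values are systematically smaller, we might, for example, compare the mean difference to this standard error (a sketch, not part of the original analysis):

Code
mean(rmse_alt_diff)                     # average pointwise RMSE difference (Student-t minus Gaussian)
mean(rmse_alt_diff) / se_rmse_alt_diff  # rough standardized difference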

Generative modelling with BRMS

Likelihood Density Scores

We have looked at variations of \(R^2\) and RMSE metrics. Next we use the log-likelihood of models as a predictive metric more generally.

The log-likelihood plays a pivotal role not only in deriving the posterior in Bayesian statistics but also in obtaining maximum likelihood estimates in a frequentist framework. Intuitively, the higher the likelihood of the data given the model's parameter estimates (represented either as posterior draws or as point estimates), the better the fit of the model to the data. Many important predictive metrics, Bayesian or otherwise, are based on log-likelihood scores.

Generative modelling with BRMS

Likelihood Density Scores

Since log is a strictly monotonic transformation, we are not changing anything fundamental by looking at log-likelihoods instead of likelihoods. However, we are making the math much simpler by working with sums instead of products. In particular, this concerns computing gradients, because the gradient of a sum is just the sum of the individual (pointwise) gradients. Much of modern statistics and ML relies on this property.

brms comes with a dedicated log_lik method that does all the required math.

ll_epi_gaussian3 <- brms::log_lik(fit_epi_gaussian3)
str(ll_epi_gaussian3)
 num [1:4000, 1:236] -3 -2.97 -2.95 -2.98 -3.02 ...
 - attr(*, "dimnames")=List of 2
  ..$ : NULL
  ..$ : NULL

Generative modelling with BRMS

Likelihood Density Scores

The output of log_lik has the same structure as posterior_predict and friends, that is, it has as many columns as we have observations and as many rows as we have posterior draws.

Code
lattice::histogram(colMeans(ll_epi_gaussian3), nint=30, type = "density")

Per-observation (pointwise) log-likelihoods of model fit_epi_gaussian3.

Generative modelling with BRMS

Likelihood Density Scores

Similar to the RMSE_alt metric earlier, we average over posterior draws per observation such that we obtain one log-likelihood value per observation.

The mean of the log-likelihood differences is slightly positive, which points to a slightly better fit of the Student-t model.

Code
llm_epi_student1  <- 
  colMeans( brms::log_lik(fit_epi_student1) )
llm_epi_gaussian3 <- 
  colMeans( brms::log_lik(fit_epi_gaussian3) )

llm_epi_diff <- 
  llm_epi_student1 - llm_epi_gaussian3
mean(llm_epi_diff)
[1] 0.5009
Code
llm_epi_gaussian3 <- colMeans(ll_epi_gaussian3)
lattice::histogram(~llm_epi_gaussian3 + llm_epi_student1 + llm_epi_diff, nint=50, type = "density")

Pointwise log-likelihood values of model fit_epi_gaussian3 and fit_epi_student1 as well as their pointwise log-likelihood differences.

Generative modelling with BRMS

Likelihood Density Scores

It is more typical to work with sums instead of means of log-likelihood values over observations, a quantity that we call LPD:
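Written out, the quantity computed below (summing the per-observation means of the log-likelihood draws) is

\[ \mathrm{LPD} = \sum_{n=1}^{N}\frac{1}{S}\sum_{s=1}^{S}\log \pi\left(y_n \mid x_n, \theta^{(s)}\right) \]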

lpd_epi_diff <- sum(llm_epi_diff)

The corresponding standard error is also not particularly difficult to obtain:

se_sum <- function(x) {sd(x) * sqrt(length(x))}

se_lpd_epi_diff <- se_sum(llm_epi_diff)

Generative modelling with BRMS

Out-of-sample predictions

Code
set.seed(8973)
splits <- epilepsy |> rsample::initial_split(prop = 0.8)
epilepsy_train <- rsample::training(splits)
epilepsy_test  <- rsample::testing(splits)
Code
fit_student1_train <- 
  brms::brm(
    count ~ Trt * Base, data = epilepsy_train, family = brms::student()
    , silent = 2, seed = 8740, file = "fits/fit_student1_train")

llm_epi_student1_test <- colMeans( brms::log_lik(fit_student1_train, newdata = epilepsy_test) )

fit_gaussian3_train <- 
  brms::brm(
    count ~ Trt * Base, data = epilepsy_train
    , silent = 2, seed = 8740, file = "fits/fit_gaussian3_train")

llm_epi_gaussian3_test <- colMeans( brms::log_lik(fit_gaussian3_train, newdata = epilepsy_test) )

Generative modelling with BRMS

Out-of-sample predictions

Code
llm_epi_diff_test <- llm_epi_student1_test - llm_epi_gaussian3_test
lattice::histogram(llm_epi_diff_test |> matrix(), nint=50, type = "density")

Code
(elpd_epi_diff_test <- sum(llm_epi_diff_test))
[1] 49480
Code
(se_elpd_epi_diff_test <- se_sum(llm_epi_diff_test))
[1] 742.9

Generative modelling with BRMS

Out-of-sample predictions

In leave-one-out cross-validation (LOO-CV), we perform \(N\) training-test splits, each time leaving out a single observation, fitting the model on the remaining \(N-1\) observations, and then evaluating model fit on that single left-out observation.

loo_epi_gaussian3 <- brms::loo(fit_epi_gaussian3)
loo_epi_gaussian3

Computed from 4000 by 236 log-likelihood matrix.

         Estimate   SE
elpd_loo   -831.4 43.0
p_loo        21.5 12.5
looic      1662.8 86.1
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.1]).

Pareto k diagnostic values:
                         Count Pct.    Min. ESS
(-Inf, 0.7]   (good)     234   99.2%   1310    
   (0.7, 1]   (bad)        0    0.0%   <NA>    
   (1, Inf)   (very bad)   2    0.8%   <NA>    
See help('pareto-k-diagnostic') for details.
loo_epi_student1 <- brms::loo(fit_epi_student1)
loo_epi_student1

Computed from 4000 by 236 log-likelihood matrix.

         Estimate   SE
elpd_loo   -704.5 23.5
p_loo         7.2  0.6
looic      1409.0 46.9
------
MCSE of elpd_loo is 0.1.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.5, 1.3]).

All Pareto k estimates are good (k < 0.7).
See help('pareto-k-diagnostic') for details.
brms::loo_compare(loo_epi_gaussian3, loo_epi_student1)
                  elpd_diff se_diff
fit_epi_student1     0.0       0.0 
fit_epi_gaussian3 -126.9      37.0 
brms::loo(fit_epi_gaussian3, fit_epi_student1)
Warning: Found 2 observations with a pareto_k > 0.7 in
model 'fit_epi_gaussian3'. We recommend to set
'moment_match = TRUE' in order to perform moment
matching for problematic observations.
Output of model 'fit_epi_gaussian3':

Computed from 4000 by 236 log-likelihood matrix.

         Estimate   SE
elpd_loo   -831.4 43.0
p_loo        21.5 12.5
looic      1662.8 86.1
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.1]).

Pareto k diagnostic values:
                         Count Pct.    Min. ESS
(-Inf, 0.7]   (good)     234   99.2%   1310    
   (0.7, 1]   (bad)        0    0.0%   <NA>    
   (1, Inf)   (very bad)   2    0.8%   <NA>    
See help('pareto-k-diagnostic') for details.

Output of model 'fit_epi_student1':

Computed from 4000 by 236 log-likelihood matrix.

         Estimate   SE
elpd_loo   -704.5 23.5
p_loo         7.2  0.6
looic      1409.0 46.9
------
MCSE of elpd_loo is 0.1.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.5, 1.3]).

All Pareto k estimates are good (k < 0.7).
See help('pareto-k-diagnostic') for details.

Model comparisons:
                  elpd_diff se_diff
fit_epi_student1     0.0       0.0 
fit_epi_gaussian3 -126.9      37.0 
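The warning above suggests moment matching for the two problematic observations. A sketch of how we might follow that advice: moment matching requires all parameters to have been saved at fit time, so fit_epi_gaussian3 would first need to be refit with save_pars = save_pars(all = TRUE).

Code
# refit with all parameters saved, then rerun LOO with moment matching (a sketch)
fit_epi_gaussian3_mm <- update(fit_epi_gaussian3, save_pars = brms::save_pars(all = TRUE))
loo_epi_gaussian3_mm <- brms::loo(fit_epi_gaussian3_mm, moment_match = TRUE)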

Generative modelling with BRMS

Prior predictive performance

We will now take a fundamentally different approach by evaluating the prior predictive distribution, that is, by looking at what predictions a model implies before seeing any data. The starting point for investigating prior predictive performance is to perform graphical prior predictive checks.

Code
prior_epi_gaussian6 <-
  brms::prior(normal(6, 3), class = "Intercept") +
  brms::prior(normal(0, 5), class = "b", coef = "Trt1") +
  brms::prior(normal(0, 1), class = "b", coef = "Base") +
  brms::prior(normal(0, 1), class = "b", coef = "Trt1:Base") +
  brms::prior(normal(0, 15), class = "sigma")
prior_epi_gaussian6 
         prior     class      coef group resp dpar nlpar   lb   ub source
  normal(6, 3) Intercept                                  <NA> <NA>   user
  normal(0, 5)         b      Trt1                        <NA> <NA>   user
  normal(0, 1)         b      Base                        <NA> <NA>   user
  normal(0, 1)         b Trt1:Base                        <NA> <NA>   user
 normal(0, 15)     sigma                                  <NA> <NA>   user

Generative modelling with BRMS

Prior predictive performance

Here we are "fitting" the model with the option sample_prior = "only", which ensures that Stan ignores the likelihood contribution to the posterior, such that the posterior directly resembles the prior.

Code
fit_prior_epi_gaussian6 <- 
  brms::brm(
    count ~ 1 + Trt * Base,
    data = epilepsy,
    prior = prior_epi_gaussian6,
    sample_prior = "only",
    file = "fits/fit_prior_epi_gaussian6"
  )
summary(fit_prior_epi_gaussian6)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: count ~ 1 + Trt * Base 
   Data: epilepsy (Number of observations: 236) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat
Intercept     4.68     36.97   -69.25    80.33 1.00
Trt1         -0.09      5.05    -9.95     9.64 1.00
Base          0.03      1.02    -1.99     2.12 1.00
Trt1:Base     0.02      1.01    -1.97     2.02 1.00
          Bulk_ESS Tail_ESS
Intercept     4020     2479
Trt1          4529     2679
Base          4120     2762
Trt1:Base     3915     1971

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat
sigma    11.91      9.22     0.42    34.07 1.00
      Bulk_ESS Tail_ESS
sigma     2673     1598

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Generative modelling with BRMS

Prior predictive performance

For all practical purposes, our "prior-only" brms model can be post-processed like any other brms model.

brms::pp_check(fit_prior_epi_gaussian6, ndraws = 100) + xlim(-150, 150)

Generative modelling with BRMS

Marginal likelihood-based metrics

To mathematically formalize the prior predictive performance of a model, consider the marginal likelihood that appears in the denominator of Bayes theorem (aka the evidence): \[ p(y)=\int p(y\vert\theta)p(\theta)\,d\theta \] Making the model explicit, we can write this as the likelihood of the data given the model and do inference about models based on the marginal likelihood: \[ p(y\vert M)=\int p(y\vert\theta,M)p(\theta\vert M)\,d\theta \] Absolute marginal likelihood values \(p(y\vert M)\) are very hard to interpret; we only know that higher is better.

Generative modelling with BRMS

Marginal likelihood-based metrics

The most common such comparative metric is the Bayes factor, defined as the ratio of two models’ marginal likelihoods:

\[ \mathrm{BF}_{1,2}=\frac{p(y\vert M_{1})}{p(y\vert M_{2})} \] If the Bayes factor is greater than 1, the data \(y\) have a higher likelihood given model \(M_1\) compared to \(M_2\), and vice versa.

Generative modelling with BRMS

Marginal likelihood-based metrics

We can say “Given the data \(y\), model \(M_1\) is more likely than model \(M_2\)” if we use the posterior odds:

\[ \frac{p(M_{1}\vert y)}{p(M_{2}\vert y)}=\frac{p(y\vert M_{1})}{p(y\vert M_{2})}\frac{p(M_{1})}{p(M_{2})}=\mathrm{BF}_{1,2}\times\frac{p(M_{1})}{p(M_{2})} \] We often set the prior odds to 1, not because we actually believe the models are equally likely a priori, but simply out of convenience; just as we often set wide or even completely flat priors on parameters.

Generative modelling with BRMS

Marginal likelihood-based metrics

Since the Bayes factor is based on marginal likelihoods, the computational challenges are substantial. Fortunately, there is one class of algorithms that enables reliable computation of (log) marginal likelihood on the basis of posterior draws. This class of algorithms is called bridge sampling.

Marginal likelihood estimation via bridge sampling usually requires several times more posterior draws than the estimation of posterior moments or quantiles (i.e., what we usually do with posterior draws).

Code
fit_epi_gaussian6 <- 
  brms::brm(
    count ~ 1 + Trt * Base, data = epilepsy,
    prior = prior_epi_gaussian6,
    save_pars = brms::save_pars(all = TRUE),
    iter = 5000, warmup = 1000,
    file = "fits/fit_epi_gaussian6"
  )

logml_epi_gaussian6 <- brms::bridge_sampler(fit_epi_gaussian6, silent = TRUE);
summary(logml_epi_gaussian6)

Bridge sampling log marginal likelihood estimate 
(method = "normal", repetitions = 1):

 -831.3

Error Measures:

 Relative Mean-Squared Error: 6.947e-07
 Coefficient of Variation: 0.0008335
 Percentage Error: 0%

Note:
All error measures are approximate.
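To turn such log marginal likelihoods into a Bayes factor, we would need a second model fit with the same save_pars(all = TRUE) setting. The sketch below uses a hypothetical fit_epi_student6 that is not fit above:

Code
# hypothetical second model (not fit above); it would also need save_pars(all = TRUE)
# logml_epi_student6 <- brms::bridge_sampler(fit_epi_student6, silent = TRUE)
# Bayes factor of the Gaussian model over the Student-t model, on the natural scale:
# exp(logml_epi_gaussian6$logml - logml_epi_student6$logml)
# brms also provides a wrapper: brms::bayes_factor(fit_epi_gaussian6, fit_epi_student6)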

More

Recap

  • We’ve had the smallest possible taste of statistical programming using Bayes theorem and sampling methods, in the context of addressing the limitations of off-the-shelf implementations of statistical methods and algorithms.