Introduction

In this document I’m going to fit a change-point SEIR model. The change-point is the time at which the effective reproduction number \(R_e\) switches from one value to another. I’m going to assume that there is only one change-point, but it would be easy enough to fit models with more change-points. In the first part below I will simulate some data. In the second part I will fit the model, first assuming good starting values and reasonable priors, and then exploring what happens when the starting values are specified incorrectly (or stochastically).

Part 1: Simulating from an SEIR model

Following on from previous files that show how to simulate data from an SEIR model, this is a version with a changing \(R_e\) value with a single change-point at a known time. Of course, in reality the change-point has a lagged effect, but I will deal with that issue elsewhere. Start by loading the packages:

library(R2jags)
library(runjags)
library(mcmcplots)
library(tidyverse)
library(gridExtra)

Here is the code to simulate an SEIR model:

# JAGS program that forward-simulates a stochastic SEIR epidemic whose
# effective reproduction number R_e switches value at a single change-point.
# All stochastic nodes are declared in the data block; the model block only
# holds a dummy node.
# Expected data: N, T, cp, alpha[1:2], gamma_E, gamma_I,
#                I_start, E_start, R_start.
seir_cp_sim <- '
data {
  # Transition counts for each time step
  for (t in 1:T) {
    N_S_E[t] ~ dpois(beta[t] * E[t] * S[t] / N)
    N_S_I[t] ~ dpois(beta[t] * I[t] * S[t] / N)
    N_E_I[t] ~ dbinom(p_E_I, E[t])
    N_I_R[t] ~ dbinom(p_I_R, I[t])
    
    # Varying transmission rates. Note the exposed and infected transmission rates are assumed equal
    beta[t] = R_e[t] / (1/gamma_E + 1/gamma_I) 
    # Now varying R_e - changes at the change-point.
    # step(t - cp) is 0 for t < cp and 1 for t >= cp, so alpha[1] applies
    # before the change-point and alpha[2] from the change-point onwards
    R_e[t] = exp(alpha[step(t - cp) + 1])
  }

  # These are the known time evolution steps:
  for(t in 2:T) {
    S[t] = S[t-1] - N_S_E[t-1] - N_S_I[t-1]
    E[t] = E[t-1] + N_S_E[t-1] + N_S_I[t-1] - N_E_I[t-1]
    I[t] = I[t-1] + N_E_I[t-1] - N_I_R[t-1]
    R[t] = R[t-1] + N_I_R[t-1]
  }

  # Initial values for all four compartments
  I[1] = I_start
  E[1] = E_start
  R[1] = R_start
  # Susceptibles are whatever is left over, so that S + E + I + R = N
  S[1] = N - E[1] - I[1] - R[1]
  
  # Per-step probabilities of leaving E and I, from the holding-time rates
  p_E_I = 1 - exp( - gamma_E )
  p_I_R = 1 - exp( - gamma_I )
}
model{
  fake = 0
}
'

Now set up the parameters for that model. Assume the lockdown happens on day 50

# Population and time grid
N <- 4.9 * 10^6 # Population size
T <- 200 # Number of time steps to simulate
cp <- 50 # Time step of the change-point (the lockdown)

# log(R_e) before and after the lockdown
alpha <- log(c(5, 0.5))

# Mean time spent exposed (days) and the corresponding rate (1 / days)
gamma_E_inv <- 6.6
gamma_E <- 1 / gamma_E_inv

# Mean time spent infected (days) and the corresponding rate (1 / days)
gamma_I_inv <- 7.4
gamma_I <- 1 / gamma_I_inv

# Initial number of individuals in the I, E and R compartments
I_start <- 10
E_start <- 0
R_start <- 0

We can then set up a list and use these to simulate data

# The simulator has no unknowns: every parameter is handed to JAGS as data
data <- list(
  N = N,
  T = T,
  gamma_E = gamma_E,
  gamma_I = gamma_I,
  alpha = alpha,
  cp = cp,
  # Initial compartment sizes
  I_start = I_start,
  E_start = E_start,
  R_start = R_start
)

# Draw a single realisation of the four compartment trajectories
out <- run.jags(
  seir_cp_sim,
  data = data,
  monitor = c("S", "E", "I", "R"),
  sample = 1,
  n.chains = 1,
  summarise = FALSE
)
## Compiling rjags model...
## Calling the simulation using the rjags method...
## Note: the model did not require adaptation
## Burning in the model for 4000 iterations...
## Running the model for 1 iterations...
## Simulation complete
## Finished running the simulation
# Convert the runjags result to a coda mcmc object for downstream reshaping
Simulated <- coda::as.mcmc(out)

This, quite unhelpfully, produces a long vector containing first the S compartment, then the E, then the I, then the R

We can sort it out with:

# Reshape the single simulated draw into long format (one row per compartment
# and time step) and plot the trajectories.
# The compartment and time index are parsed from the monitored node names
# (e.g. "S[12]") instead of assuming the order in which monitors are returned.
node_names <- coda::varnames(Simulated)
stopifnot(length(node_names) == length(as.vector(Simulated)))
dat <- tibble(
  simulation = as.vector(Simulated),
  compartment = sub("\\[.*$", "", node_names), # text before "["
  t = as.integer(sub("^.*\\[([0-9]+)\\]$", "\\1", node_names)) # index in "[]"
)
ggplot(dat, aes(x = t, y = simulation, colour = compartment)) +
  geom_line()

It’s interesting how long it takes to hit exponential growth.

Part 2 Fitting the SEIR change point model to simulated data

This is the code I will use to fit the model. It’s the same as above but includes some priors on the parameters:

# JAGS model used to fit the change-point SEIR model. The structure mirrors
# the simulator, with priors on the holding-time rates and on log(R_e)
# before and after the change-point.
# Expected data: N, T, cp, I_start, E_start, R_start,
#                gamma_E_1, gamma_E_2, gamma_I_1, gamma_I_2,
#                mean_before, mean_after, sd_before, sd_after,
#                plus whichever transition counts are observed.
seir_cp_model <- '
model {
  # Likelihood
  for (t in 1:T) {
    N_S_E[t] ~ dpois(beta[t] * E[t] * S[t] / N)
    N_S_I[t] ~ dpois(beta[t] * I[t] * S[t] / N)
    N_E_I[t] ~ dbinom(p_E_I, E[t])
    N_I_R[t] ~ dbinom(p_I_R, I[t])
    
    # Varying transmission rates. Note the exposed and infected transmission rates are assumed equal
    beta[t] = R_e[t] / (1/gamma_E + 1/gamma_I) 
    # Now varying R_e - changes at the change-point.
    # step(t - cp) is 0 for t < cp and 1 for t >= cp, so alpha[1] applies
    # before the change-point and alpha[2] from the change-point onwards
    R_e[t] = exp(alpha[step(t - cp) + 1])
  }

  # These are the known time evolution steps:
  for(t in 2:T) {
    S[t] = S[t-1] - N_S_E[t-1] - N_S_I[t-1]
    E[t] = E[t-1] + N_S_E[t-1] + N_S_I[t-1] - N_E_I[t-1]
    I[t] = I[t-1] + N_E_I[t-1] - N_I_R[t-1]
    R[t] = R[t-1] + N_I_R[t-1]
  }

  # Initial values for all four compartments
  I[1] = I_start
  E[1] = E_start
  R[1] = R_start
  # Susceptibles are whatever is left over, so that S + E + I + R = N
  S[1] = N - E[1] - I[1] - R[1]

  # Per-step probabilities of leaving E and I, from the holding-time rates
  p_E_I = 1 - exp( - gamma_E )
  p_I_R = 1 - exp( - gamma_I )
  
  # Prior distributions on the holding-time rates (shape, rate)
  gamma_E ~ dgamma(gamma_E_1, gamma_E_2)
  gamma_I ~ dgamma(gamma_I_1, gamma_I_2)
  
  # Prior distributions on the alpha values.
  # JAGS dnorm takes a precision, hence sd^-2
  alpha[1] ~ dnorm(mean_before, sd_before^-2)
  alpha[2] ~ dnorm(mean_after, sd_after^-2)
}
'

Finally, before starting this exercise, I’m going to simulate the transition data rather than the compartment values from the model:

# Simulate what would actually be the data
# (run once and cached; uncomment to regenerate the file)
# set.seed(123)
# out_cp_data = run.jags(seir_cp_sim,
#                        data = data,
#                        monitor = c("N_E_I", "N_I_R", "N_S_E", "N_S_I"),
#                        sample = 1, n.chains = 1, summarise=FALSE)
# saveRDS(out_cp_data, file = 'simulated_inputs_cp.rds')
out_cp_data <- readRDS(file = 'simulated_inputs_cp.rds')
Simulated2 <- coda::as.mcmc(out_cp_data)

# Long format: one row per transition type and time step. Labels are parsed
# from the monitored node names (e.g. "N_E_I[12]") instead of assuming the
# order in which monitors are returned, so a reordering cannot silently
# mislabel the series that are later fed to the model.
node_names2 <- coda::varnames(Simulated2)
stopifnot(length(node_names2) == length(as.vector(Simulated2)))
dat2 <- tibble(
  simulation = as.vector(Simulated2),
  compartment = sub("\\[.*$", "", node_names2), # text before "["
  t = as.integer(sub("^.*\\[([0-9]+)\\]$", "\\1", node_names2)) # index in "[]"
)
ggplot(dat2, aes(x = t, y = simulation, colour = compartment)) +
  geom_line()

# Wide format: one row per time step, one column per transition type
dat3 <- dat2 %>%
  pivot_wider(names_from = compartment, values_from = simulation)

Part 1

Give the model N_E_I only, along with all the other parameters, and see if it can estimate the remaining transition counts (N_S_E, N_S_I and N_I_R).

# Factor to multiply up the gamma shape and rate. The prior mean
# (shape / rate) stays fixed at the true rate; larger values give a lower sd
fac <- 10

# Data passed to the fitting model: known constants, prior hyper-parameters,
# the single observed transition series and the initial compartment sizes
jags_run_data <- list(
  N = N,
  T = T,
  cp = cp,
  # Gamma prior on gamma_E, centred on the true value
  gamma_E_1 = gamma_E * fac,
  gamma_E_2 = fac,
  # Gamma prior on gamma_I, centred on the true value
  gamma_I_1 = gamma_I * fac,
  gamma_I_2 = fac,
  # Normal priors on log(R_e) before and after the change-point
  mean_before = alpha[1],
  mean_after = alpha[2],
  sd_before = 1,
  sd_after = 1,
  # The only observed transition counts
  N_E_I = dat3$N_E_I,
  # These are the starting values for each component
  I_start = I_start,
  E_start = E_start,
  R_start = R_start
)

# Initial values for the sampler: start the three unobserved transition
# series at their simulated (true) values, read from dat3 at call time
inits <- function() {
  list(
    N_S_E = dat3[["N_S_E"]],
    N_I_R = dat3[["N_I_R"]],
    N_S_I = dat3[["N_S_I"]]
  )
}

# run jags
# (the fit is slow, so it was run once and cached; uncomment to refit)
# run_cp1 = jags(data = jags_run_data,
#                inits = inits,
#                parameters.to.save = c("N_S_E", "N_S_I", 
#                                       "N_I_R", "beta", "alpha",
#                                       "gamma_E", "gamma_I"),
#                model.file = textConnection(seir_cp_model))
# saveRDS(run_cp1, file = 'jags_run_sier_cp1.rds')
# Load the cached fit from disk
run_cp1 <- readRDS("jags_run_sier_cp1.rds")

That ran very, very slowly. Let’s see if it converged:

# Default plot of the cached fit, used here as a visual convergence check
# NOTE(review): presumably dispatches to R2jags' plot method for rjags objects - confirm
plot(run_cp1)

Not great convergence, but let’s see if the values of N_S_E, N_I_R, and N_S_I match the true values:

# Posterior medians of the three unobserved transition series
post_N_S_E <- run_cp1$BUGSoutput$median$N_S_E
post_N_S_I <- run_cp1$BUGSoutput$median$N_S_I
post_N_I_R <- run_cp1$BUGSoutput$median$N_I_R

# Scatter plot of posterior median against the simulated truth.
# Built with ggplot() because qplot() is deprecated; the dashed y = x line
# marks perfect agreement.
#   truth       - simulated transition counts
#   post_median - posterior medians, same length as truth
#   label       - name of the series, used in the axis labels
plot_truth_vs_median <- function(truth, post_median, label) {
  ggplot(
    tibble(truth = truth, post_median = as.vector(post_median)),
    aes(x = truth, y = post_median)
  ) +
    geom_point() +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    labs(x = paste("Simulated", label), y = paste("Posterior median", label))
}

p1 <- plot_truth_vs_median(dat3$N_S_E, post_N_S_E, "N_S_E")
p2 <- plot_truth_vs_median(dat3$N_S_I, post_N_S_I, "N_S_I")
p3 <- plot_truth_vs_median(dat3$N_I_R, post_N_I_R, "N_I_R")
grid.arrange(p1, p2, p3, ncol = 3)

and see if the parameter values match the truth

# Posterior draws of the model parameters
post_alpha <- run_cp1$BUGSoutput$sims.list$alpha
post_gamma_E <- run_cp1$BUGSoutput$sims.list$gamma_E
post_gamma_I <- run_cp1$BUGSoutput$sims.list$gamma_I

# Histogram of posterior draws with a vertical line at the true value and a
# translucent red histogram of prior draws on top.
# Built with ggplot() because qplot() is deprecated. The fill and
# transparency are fixed settings, so they are set outside aes(); no legend
# is produced and none needs hiding.
#   post_draws  - posterior samples of one parameter
#   prior_draws - samples from that parameter's prior
#   truth       - value used to simulate the data
#   label       - x-axis label
plot_post_vs_prior <- function(post_draws, prior_draws, truth, label) {
  ggplot() +
    geom_histogram(
      data = tibble(value = as.vector(post_draws)),
      aes(x = value),
      bins = 30
    ) +
    geom_vline(xintercept = truth) +
    geom_histogram(
      data = tibble(value = prior_draws),
      aes(x = value),
      bins = 30,
      fill = "red",
      alpha = 0.3
    ) +
    labs(x = label)
}

# Number of prior draws per panel
n_prior <- 10000

# Prior draws use the same settings as the fitted model: normal priors with
# sd 1 centred on the true alpha, gamma priors with shape = rate * fac
p4 <- plot_post_vs_prior(
  post_alpha[, 1], rnorm(n_prior, alpha[1], 1), alpha[1], "alpha[1]"
)
p5 <- plot_post_vs_prior(
  post_alpha[, 2], rnorm(n_prior, alpha[2], 1), alpha[2], "alpha[2]"
)
p6 <- plot_post_vs_prior(
  post_gamma_E, rgamma(n_prior, gamma_E * fac, fac), gamma_E, "gamma_E"
)
p7 <- plot_post_vs_prior(
  post_gamma_I, rgamma(n_prior, gamma_I * fac, fac), gamma_I, "gamma_I"
)

grid.arrange(p4, p5, p6, p7, ncol = 2)

Looks pretty good!