This document follows on from previous files that show how to simulate data from an SEIR model. Start by loading the required packages:
library(R2jags)
library(runjags)
library(mcmcplots)
library(tidyverse)
library(gridExtra)
Here is the code to simulate an SEIR model:
# JAGS program used purely as a forward simulator. Everything is declared in
# the `data` block: the transition counts N_S_E, N_S_I, N_E_I and N_I_R are
# drawn from their distributions and the S/E/I/R compartments follow from the
# deterministic update equations. The `model` block only holds a placeholder
# node.
#
# Fix: S[1] now also subtracts E[1], so that S + E + I + R equals N at t = 1
# even when E_start is non-zero (previously the exposed individuals were
# counted on top of N).
seir_sim <- '
data {
# Likelihood
for (t in 1:T) {
N_S_E[t] ~ dpois(beta_E * E[t] * S[t] / N)
N_S_I[t] ~ dpois(beta_I * I[t] * S[t] / N)
N_E_I[t] ~ dbinom(p_E_I, E[t])
N_I_R[t] ~ dbinom(p_I_R, I[t])
}
# These are the known time evolution steps:
for(t in 2:T) {
S[t] = S[t-1] - N_S_E[t-1] - N_S_I[t-1]
E[t] = E[t-1] + N_S_E[t-1] + N_S_I[t-1] - N_E_I[t-1]
I[t] = I[t-1] + N_E_I[t-1] - N_I_R[t-1]
R[t] = R[t-1] + N_I_R[t-1]
}
# Need a value for S[1], E[1], and I[1]
I[1] = I_start
E[1] = E_start
R[1] = R_start
S[1] = N - I[1] - E[1] - R[1] # Left over
# Probabilities and R_0
p_E_I = 1 - exp( - gamma_E )
p_I_R = 1 - exp( - gamma_I )
R_0 = (beta_E/gamma_E) + (beta_I/gamma_I)
}
model{
fake = 0
}
'
Now set up the parameters for that model. This has constant R0.
# Constant-R0 scenario: population, time horizon, rates and initial state.
N <- 4.9 * 10^6 # Population size
# Maximum time steps. The name `T` is kept because the JAGS programs and the
# later chunks refer to it (note that it masks R's `T` shorthand for TRUE).
T <- 200

# Length of time exposed (days) and the corresponding rate (inverse days)
gamma_E_inv <- 6.6
gamma_E <- 1 / gamma_E_inv

# Length of time infected (days) and the corresponding rate (inverse days)
gamma_I_inv <- 7.4
gamma_I <- 1 / gamma_I_inv

# Transmission rates for the Exposed and Infected compartments
beta_E_inv <- 5
beta_E <- 1 / beta_E_inv
beta_I_inv <- 5
beta_I <- 1 / beta_I_inv

# R0 is a function of the four rates above
R_0 <- beta_E / gamma_E + beta_I / gamma_I

# Starting values for each component
I_start <- 10
E_start <- 0
R_start <- 0
We can then set up a list and use these to simulate data
# For the simulation step the parameters and the starting values are all
# handed to JAGS as data.
data <- list(
  N = N,
  T = T,
  gamma_E = gamma_E,
  gamma_I = gamma_I,
  beta_E = beta_E,
  beta_I = beta_I,
  # Starting values for each component
  I_start = I_start,
  E_start = E_start,
  R_start = R_start
)

# Run JAGS once (one chain, one retained sample) and monitor the compartments
out <- run.jags(
  model = seir_sim,
  monitor = c("S", "E", "I", "R"),
  data = data,
  n.chains = 1,
  sample = 1,
  summarise = FALSE
)
## Compiling rjags model...
## Calling the simulation using the rjags method...
## Note: the model did not require adaptation
## Burning in the model for 4000 iterations...
## Running the model for 1 iterations...
## Simulation complete
## Finished running the simulation

# Convert the runjags result to a coda mcmc object for extraction below
Simulated <- coda::as.mcmc(out)
This, quite unhelpfully, produces a long vector containing first the S compartment, then the E, then the I, then the R
We can sort it out with:
# Reshape the flat simulation vector into long format. The values arrive as
# T entries for S, then E, then I, then R, so label them in that order and
# attach the time index 1..T to each compartment.
dat <- tibble(
  simulation = as.vector(Simulated),
  compartment = rep(c("S", "E", "I", "R"), each = T),
  t = rep(seq_len(T), times = 4)
)

# One line per compartment over time
ggplot(data = dat, mapping = aes(t, simulation, colour = compartment)) +
  geom_line()
It’s interesting how long it takes to hit exponential growth.
I’m going to do three versions of fitting the model to simulated data:

1. Giving the model N_E_I and N_I_R and the starting values, and seeing if it can estimate all the other parameters (with good prior distributions).
2. Giving the model N_E_I and N_I_R and the wrong starting values, to see if it can still estimate the other parameters ok.
3. Giving the model only N_E_I with strong prior distributions on the other parameters and seeing if it can estimate everything else.

Here is the JAGS code I will use for fitting these models. The prior distributions at the bottom have fixed hyper-parameters to enable me to control what you can learn or not from the model.
# JAGS program used for fitting. It has the same likelihood and update
# equations as the simulator, plus gamma priors on the four rates whose
# hyper-parameters are passed in as data.
#
# Fix: S[1] now also subtracts E[1], so that S + E + I + R equals N at t = 1.
# This matters whenever a non-zero E_start is supplied; before the fix those
# exposed individuals were counted on top of N.
seir_model <- '
model {
# Likelihood
for (t in 1:T) {
N_S_E[t] ~ dpois(beta_E * E[t] * S[t] / N)
N_S_I[t] ~ dpois(beta_I * I[t] * S[t] / N)
N_E_I[t] ~ dbinom(p_E_I, E[t])
N_I_R[t] ~ dbinom(p_I_R, I[t])
}
# These are the known time evolution steps:
for(t in 2:T) {
S[t] = S[t-1] - N_S_E[t-1] - N_S_I[t-1]
E[t] = E[t-1] + N_S_E[t-1] + N_S_I[t-1] - N_E_I[t-1]
I[t] = I[t-1] + N_E_I[t-1] - N_I_R[t-1]
R[t] = R[t-1] + N_I_R[t-1]
}
# Need a value for S[1], E[1], and I[1]
I[1] = I_start
E[1] = E_start
R[1] = R_start
S[1] = N - I[1] - E[1] - R[1] # Left over
# Probabilities and R_0
p_E_I = 1 - exp( - gamma_E )
p_I_R = 1 - exp( - gamma_I )
R_0 = (beta_E/gamma_E) + (beta_I/gamma_I)
# Prior distributions
gamma_E ~ dgamma(gamma_E_1, gamma_E_2)
gamma_I ~ dgamma(gamma_I_1, gamma_I_2)
beta_E ~ dgamma(gamma_beta_E_1, gamma_beta_E_2)
beta_I ~ dgamma(gamma_beta_I_1, gamma_beta_I_2)
}
'
Finally, before starting this exercise, I’m going to simulate the transition data rather than the compartment values from the model:
# Simulate what would actually be the data: the four transition counts.
# The simulation itself is kept commented out; its saved result is re-used.
# set.seed(123)
# out2 = run.jags(seir_sim,
# data = data,
# monitor = c("N_E_I", "N_I_R", "N_S_E", "N_S_I"),
# sample = 1, n.chains = 1, summarise=FALSE)
# saveRDS(out2, file = 'simulated_inputs.rds')
out2 <- readRDS("simulated_inputs.rds")
Simulated2 <- coda::as.mcmc(out2)

# Long format: T values per transition, labelled in the monitored order
dat2 <- tibble(
  simulation = as.vector(Simulated2),
  compartment = rep(c("N_E_I", "N_I_R", "N_S_E", "N_S_I"), each = T),
  t = rep(seq_len(T), times = 4)
)

# One line per transition over time
ggplot(data = dat2, mapping = aes(t, simulation, colour = compartment)) +
  geom_line()

# Wide format: one row per time step, one column per transition
dat3 <- pivot_wider(dat2, names_from = compartment, values_from = simulation)
Giving the model N_E_I and N_I_R and all the other parameters, and seeing if it can estimate N_S_E and N_S_I
# Factor to multiply up the gamma distributions. Each prior is
# dgamma(true_value * fac, fac), which fixes the mean at the true value;
# larger values of fac give a lower sd.
fac <- 10

jags_run_data <- list(
  N = N,
  T = T,
  # Gamma hyper-parameters for the four rate priors
  gamma_E_1 = gamma_E * fac,
  gamma_E_2 = fac,
  gamma_I_1 = gamma_I * fac,
  gamma_I_2 = fac,
  gamma_beta_E_1 = beta_E * fac,
  gamma_beta_E_2 = fac,
  gamma_beta_I_1 = beta_I * fac,
  gamma_beta_I_2 = fac,
  # Observed transitions
  N_E_I = dat3$N_E_I,
  N_I_R = dat3$N_I_R,
  # These are the starting values for each component
  I_start = I_start,
  E_start = E_start,
  R_start = R_start
)

# Set the starting values to be decent: start the unobserved transitions at
# their simulated values
inits <- function() {
  list(
    N_S_E = dat3$N_S_E,
    N_S_I = dat3$N_S_I
  )
}

# run jags (kept commented out; the saved fit is loaded instead)
# run1 = jags(data = jags_run_data,
# inits = inits,
# parameters.to.save = c("N_S_E", "N_S_I", "beta_E", "beta_I", "gamma_E", "gamma_I"),
# model.file = textConnection(seir_model))
# saveRDS(run1, file = 'jags_run_sier_s1.rds')
run1 <- readRDS("jags_run_sier_s1.rds")
That ran very very slowly. See if it converged
# Default plot of the loaded first fit, used here to eyeball convergence
plot(run1)
Now see if the values of N_S_E and N_S_I match the true values:
# Posterior medians of the unobserved transitions from the first fit
post_N_S_E <- run1$BUGSoutput$median$N_S_E
post_N_S_I <- run1$BUGSoutput$median$N_S_I

# Simulated (true) value against posterior median, one point per time step.
# Built with ggplot() because qplot() is deprecated in current ggplot2; the
# dashed identity line marks where estimate == truth.
p1 <- ggplot(
  data.frame(truth = dat3$N_S_E, posterior = as.vector(post_N_S_E)),
  aes(x = truth, y = posterior)
) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Simulated N_S_E", y = "Posterior median N_S_E")
p2 <- ggplot(
  data.frame(truth = dat3$N_S_I, posterior = as.vector(post_N_S_I)),
  aes(x = truth, y = posterior)
) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Simulated N_S_I", y = "Posterior median N_S_I")
grid.arrange(p1, p2, ncol = 2)
and see if the parameter values match the truth
# Posterior draws of the four rate parameters from the first fit
post_beta_E <- run1$BUGSoutput$sims.list$beta_E
post_beta_I <- run1$BUGSoutput$sims.list$beta_I
post_gamma_E <- run1$BUGSoutput$sims.list$gamma_E
post_gamma_I <- run1$BUGSoutput$sims.list$gamma_I

# Overlay posterior draws (default grey) with draws from the matching
# gamma(truth * fac, fac) prior (red), plus a vertical line at the true value.
# Defined in this chunk so that the chunk is self-contained.
#   post    posterior draws (vector or one-column matrix)
#   truth   true parameter value, which is also the prior mean
#   label   x-axis label
#   fac     prior sharpness factor used when the model was fitted
#   n_prior number of prior draws to show
# Changes relative to the earlier qplot() version:
#   * ggplot() is used because qplot() is deprecated;
#   * fill/alpha are set as constants outside aes(), so the prior really is
#     drawn red at alpha 0.3 and no legend needs to be hidden;
#   * both histograms are on the density scale, so the different numbers of
#     prior and posterior draws do not distort the comparison.
plot_prior_vs_posterior <- function(post, truth, label, fac, n_prior = 10000) {
  ggplot() +
    geom_histogram(
      data = data.frame(value = as.vector(post)),
      mapping = aes(x = value, y = after_stat(density)),
      bins = 30
    ) +
    geom_vline(xintercept = truth) +
    geom_histogram(
      data = data.frame(value = rgamma(n_prior, shape = truth * fac, rate = fac)),
      mapping = aes(x = value, y = after_stat(density)),
      bins = 30, fill = "red", alpha = 0.3
    ) +
    labs(x = label)
}

p3 <- plot_prior_vs_posterior(post_beta_E, beta_E, "beta_E", fac)
p4 <- plot_prior_vs_posterior(post_beta_I, beta_I, "beta_I", fac)
p5 <- plot_prior_vs_posterior(post_gamma_E, gamma_E, "gamma_E", fac)
p6 <- plot_prior_vs_posterior(post_gamma_I, gamma_I, "gamma_I", fac)
grid.arrange(p3, p4, p5, p6, ncol = 2)
Looks pretty good!
Now I am going to give it N_E_I and N_I_R and the wrong starting values, to see if it can still estimate the other parameters ok.
# Factor to multiply up the gamma distributions. Each prior is
# dgamma(true_value * fac, fac), which fixes the mean at the true value;
# larger values of fac give a lower sd.
fac <- 10

jags_run_data <- list(
  N = N,
  T = T,
  # Gamma hyper-parameters for the four rate priors
  gamma_E_1 = gamma_E * fac,
  gamma_E_2 = fac,
  gamma_I_1 = gamma_I * fac,
  gamma_I_2 = fac,
  gamma_beta_E_1 = beta_E * fac,
  gamma_beta_E_2 = fac,
  gamma_beta_I_1 = beta_I * fac,
  gamma_beta_I_2 = fac,
  # Observed transitions
  N_E_I = dat3$N_E_I,
  N_I_R = dat3$N_I_R,
  # These are the starting values for each component - HERE THEY ARE
  # DELIBERATELY SPECIFIED WRONG!
  I_start = I_start * 10,
  E_start = E_start + 10,
  R_start = R_start + 100
)

# Set the starting values to be decent: start the unobserved transitions at
# their simulated values
inits <- function() {
  list(
    N_S_E = dat3$N_S_E,
    N_S_I = dat3$N_S_I
  )
}

# Run jags (kept commented out; the saved fit is loaded instead)
# run2 = jags(data = jags_run_data,
# inits = inits,
# parameters.to.save = c("N_S_E", "N_S_I", "beta_E", "beta_I", "gamma_E", "gamma_I"),
# model.file = textConnection(seir_model))
# saveRDS(run2, file = 'jags_run_sier_s2.rds')
run2 <- readRDS("jags_run_sier_s2.rds")
About 40 minutes to run
# Default plot of the loaded second fit, used here to eyeball convergence
plot(run2)
It did ok, but convergence of the beta and gamma values was not great. Now see if the values of N_S_E and N_S_I match the true values:
# Posterior medians of the unobserved transitions from the second fit
post_N_S_E <- run2$BUGSoutput$median$N_S_E
post_N_S_I <- run2$BUGSoutput$median$N_S_I

# Simulated (true) value against posterior median, one point per time step.
# Built with ggplot() because qplot() is deprecated in current ggplot2; the
# dashed identity line marks where estimate == truth.
p1 <- ggplot(
  data.frame(truth = dat3$N_S_E, posterior = as.vector(post_N_S_E)),
  aes(x = truth, y = posterior)
) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Simulated N_S_E", y = "Posterior median N_S_E")
p2 <- ggplot(
  data.frame(truth = dat3$N_S_I, posterior = as.vector(post_N_S_I)),
  aes(x = truth, y = posterior)
) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Simulated N_S_I", y = "Posterior median N_S_I")
grid.arrange(p1, p2, ncol = 2)
and see if the parameter values match the truth
# Posterior draws of the four rate parameters from the second fit
post_beta_E <- run2$BUGSoutput$sims.list$beta_E
post_beta_I <- run2$BUGSoutput$sims.list$beta_I
post_gamma_E <- run2$BUGSoutput$sims.list$gamma_E
post_gamma_I <- run2$BUGSoutput$sims.list$gamma_I

# Overlay posterior draws (default grey) with draws from the matching
# gamma(truth * fac, fac) prior (red), plus a vertical line at the true value.
# Defined in this chunk so that the chunk is self-contained.
#   post    posterior draws (vector or one-column matrix)
#   truth   true parameter value, which is also the prior mean
#   label   x-axis label
#   fac     prior sharpness factor used when the model was fitted
#   n_prior number of prior draws to show
# Changes relative to the earlier qplot() version:
#   * ggplot() is used because qplot() is deprecated;
#   * fill/alpha are set as constants outside aes(), so the prior really is
#     drawn red at alpha 0.3 and no legend needs to be hidden;
#   * both histograms are on the density scale, so the different numbers of
#     prior and posterior draws do not distort the comparison.
plot_prior_vs_posterior <- function(post, truth, label, fac, n_prior = 10000) {
  ggplot() +
    geom_histogram(
      data = data.frame(value = as.vector(post)),
      mapping = aes(x = value, y = after_stat(density)),
      bins = 30
    ) +
    geom_vline(xintercept = truth) +
    geom_histogram(
      data = data.frame(value = rgamma(n_prior, shape = truth * fac, rate = fac)),
      mapping = aes(x = value, y = after_stat(density)),
      bins = 30, fill = "red", alpha = 0.3
    ) +
    labs(x = label)
}

p3 <- plot_prior_vs_posterior(post_beta_E, beta_E, "beta_E", fac)
p4 <- plot_prior_vs_posterior(post_beta_I, beta_I, "beta_I", fac)
p5 <- plot_prior_vs_posterior(post_gamma_E, gamma_E, "gamma_E", fac)
p6 <- plot_prior_vs_posterior(post_gamma_I, gamma_I, "gamma_I", fac)
grid.arrange(p3, p4, p5, p6, ncol = 2)
Finally, I am going to give it only N_E_I with strong prior distributions on the other parameters, to see if it can estimate everything else.
# Factor to multiply up the gamma distributions. Each prior is
# dgamma(true_value * fac, fac), which fixes the mean at the true value;
# larger values of fac give a lower sd.
fac <- 10

jags_run_data <- list(
  N = N,
  T = T,
  # Gamma hyper-parameters for the four rate priors
  gamma_E_1 = gamma_E * fac,
  gamma_E_2 = fac,
  gamma_I_1 = gamma_I * fac,
  gamma_I_2 = fac,
  gamma_beta_E_1 = beta_E * fac,
  gamma_beta_E_2 = fac,
  gamma_beta_I_1 = beta_I * fac,
  gamma_beta_I_2 = fac,
  # Only the E -> I transitions are observed in this scenario
  N_E_I = dat3$N_E_I,
  # These are the starting values for each component (the same values that
  # were used for the simulation, unmodified)
  I_start = I_start,
  E_start = E_start,
  R_start = R_start
)

# Set the starting values to be decent: start all three unobserved
# transitions at their simulated values
inits <- function() {
  list(
    N_S_E = dat3$N_S_E,
    N_I_R = dat3$N_I_R,
    N_S_I = dat3$N_S_I
  )
}

# Run jags (kept commented out; the saved fit is loaded instead)
# run3 = jags(data = jags_run_data,
# inits = inits,
# parameters.to.save = c("N_S_E", "N_S_I", "N_I_R", "beta_E", "beta_I", "gamma_E", "gamma_I"),
# model.file = textConnection(seir_model))
# saveRDS(run3, file = 'jags_run_sier_s3.rds')
run3 <- readRDS("jags_run_sier_s3.rds")
About 40 minutes to run
# Default plot of the loaded third fit, used here to eyeball convergence
plot(run3)
It did ok, but convergence of the beta and gamma values was not great. Now see if the values of N_S_E, N_S_I and N_I_R match the true values:
# Posterior medians of the three unobserved transitions from the third fit
post_N_S_E <- run3$BUGSoutput$median$N_S_E
post_N_S_I <- run3$BUGSoutput$median$N_S_I
post_N_I_R <- run3$BUGSoutput$median$N_I_R

# Simulated (true) value against posterior median, one point per time step.
# Built with ggplot() because qplot() is deprecated in current ggplot2; the
# dashed identity line marks where estimate == truth.
p1 <- ggplot(
  data.frame(truth = dat3$N_S_E, posterior = as.vector(post_N_S_E)),
  aes(x = truth, y = posterior)
) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Simulated N_S_E", y = "Posterior median N_S_E")
p2 <- ggplot(
  data.frame(truth = dat3$N_S_I, posterior = as.vector(post_N_S_I)),
  aes(x = truth, y = posterior)
) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Simulated N_S_I", y = "Posterior median N_S_I")
p3 <- ggplot(
  data.frame(truth = dat3$N_I_R, posterior = as.vector(post_N_I_R)),
  aes(x = truth, y = posterior)
) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  labs(x = "Simulated N_I_R", y = "Posterior median N_I_R")
grid.arrange(p1, p2, p3, ncol = 3)
and see if the parameter values match the truth
# Posterior draws of the four rate parameters from the THIRD fit.
# Fix: these were previously read from run2 (copy-paste from the second
# scenario), so the plots below repeated scenario 2 instead of showing run3.
post_beta_E <- run3$BUGSoutput$sims.list$beta_E
post_beta_I <- run3$BUGSoutput$sims.list$beta_I
post_gamma_E <- run3$BUGSoutput$sims.list$gamma_E
post_gamma_I <- run3$BUGSoutput$sims.list$gamma_I

# Overlay posterior draws (default grey) with draws from the matching
# gamma(truth * fac, fac) prior (red), plus a vertical line at the true value.
# Defined in this chunk so that the chunk is self-contained.
#   post    posterior draws (vector or one-column matrix)
#   truth   true parameter value, which is also the prior mean
#   label   x-axis label
#   fac     prior sharpness factor used when the model was fitted
#   n_prior number of prior draws to show
# Changes relative to the earlier qplot() version:
#   * ggplot() is used because qplot() is deprecated;
#   * fill/alpha are set as constants outside aes(), so the prior really is
#     drawn red at alpha 0.3 and no legend needs to be hidden;
#   * both histograms are on the density scale, so the different numbers of
#     prior and posterior draws do not distort the comparison.
plot_prior_vs_posterior <- function(post, truth, label, fac, n_prior = 10000) {
  ggplot() +
    geom_histogram(
      data = data.frame(value = as.vector(post)),
      mapping = aes(x = value, y = after_stat(density)),
      bins = 30
    ) +
    geom_vline(xintercept = truth) +
    geom_histogram(
      data = data.frame(value = rgamma(n_prior, shape = truth * fac, rate = fac)),
      mapping = aes(x = value, y = after_stat(density)),
      bins = 30, fill = "red", alpha = 0.3
    ) +
    labs(x = label)
}

p3 <- plot_prior_vs_posterior(post_beta_E, beta_E, "beta_E", fac)
p4 <- plot_prior_vs_posterior(post_beta_I, beta_I, "beta_I", fac)
p5 <- plot_prior_vs_posterior(post_gamma_E, gamma_E, "gamma_E", fac)
p6 <- plot_prior_vs_posterior(post_gamma_I, gamma_I, "gamma_I", fac)
grid.arrange(p3, p4, p5, p6, ncol = 2)