In this example we consider a simulator for the incidence of type-2 diabetes and stroke with various risk factors, as well as mortality from stroke and other causes. Interventions that aim to reduce systolic blood pressure in the population are also considered.

Data files

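To run the example, the following data files are needed (these are the files that the main script checks for): vaesto_3112_2017.csv, kuol_007_201700.csv and the MONICA data file form04_3.txt.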

For the purposes of this example, the data files (with the MONICA data unzipped) should be placed in a directory called data under the working directory.

R scripts

The R script files required to run the example can be downloaded at http://users.jyu.fi/~santikka/Sima/diabetes_and_stroke.zip

The zip file contains the following

  • diabetes_and_stroke.R is the primary script that contains the simulator definition and all the simulation runs.
  • events.R contains all the event definitions.
  • data.R contains the function to generate the initial population.
  • functions.R contains utility function definitions that are used throughout.

The contents of these scripts are also listed in the following sections for direct viewing

Main script

The file diabetes_and_stroke.R provides the code to run the entire example

## An example use case of the Sima framework constructing a simulator for 
## modelling interventions regarding salt use in a synthetic population with 
## Finnish population characteristics in the year 2017. Also includes
## interventions that aim to reduce systolic blood pressure in the population.

# Uncomment to install packages
# install.packages(c("data.table", "ggplot2", "bestNormalize", "devtools"))
# devtools::install_github("santikka/Sima")

library(Sima)
library(data.table)
library(ggplot2)
rm(list = ls())

# The file_path should be set to the location of the R-scripts, such that
# the data-directory is present in this location

# file_path <- ""
# setwd(file_path)

# Verify that all data files are available before anything else is done.
# Each file is looked up in the 'data' directory relative to the working
# directory; the script stops with an error naming the first missing file.
if (!file.exists("data/vaesto_3112_2017.csv")) {
    stop("Data file vaesto_3112_2017.csv not found in data directory.")
}

if (!file.exists("data/kuol_007_201700.csv")) {
    # The message states the complete file name, extension included,
    # in the same way as the messages for the other two files
    stop("Data file kuol_007_201700.csv not found in data directory.")
}

if (!file.exists("data/form04_3.txt")) {
    stop("Data file form04_3.txt not found in data directory.")
}

#####################################################
# Utility Functions, Events and Initial Population  #
#####################################################

# The supporting scripts are loaded in a fixed order:
#   functions.R - utility functions
#   events.R    - event definitions
#   data.R      - generator of the initial population
# source() evaluates each file in the global environment by default, so the
# objects they define are available to the rest of this script.
invisible(lapply(c("functions.R", "events.R", "data.R"), source))

##################
# Initialization #
##################

# Main simulator. The population size is controlled by 'n'; the initial
# population comes from 'initializer' (data.R) and the events applied to it
# from 'all_events' (events.R).
sim <- Simulator$new(
    n = 3.6e6,
    seeds = 242,
    initializer = initializer,
    man_events = all_events,
    keep_init = TRUE
)

########################
# Parallel computation #
########################

## Uncomment to enable a parallel cluster for remaining computations
# if (require(doParallel)) {
#     nc <- parallel::detectCores()
#     cl <- parallel::makeCluster(nc)
#     doParallel::registerDoParallel(cl)
#     ## User made functions/objects have to be manually exported, list them in 'export'
#     sim$start_cluster(cl, nc, export = c("logit", "expit", "rbern", "rbern2"))
# }

###############
# Calibration #
###############

# Uncomment to perform calibration
# calib_result <- optim(par = sqrt(c(13.51056, 161.12780, 5.41859)), 
#                       fn = mortality_objective, 
#                       method = "Nelder-Mead", 
#                       calib_sim = sim)

# Using a population of 1 million individuals, the following optimum is found:
# c(12.03600, 171.32261, 12.52546)

# Apply the calibrated parameter values to the main simulator.
# The "Stroke Mortality" event has the parameters 'initial' and 'longterm';
# only the first one ('initial') is set here.
sim$reconfigure(
    list(
        "Other Mortality" = list(shape = 12.03600, scale = 171.32261),
        "Stroke Mortality" = list(initial = 12.52546)
    )
)

# Visualize and compare simulated mortality to official statistics
sim$run(364)
sim$stop_cluster()

# Simulated deaths and group sizes by one-year age group, from age 30 upwards
mortality_sim <- mortality_df(sim$get_status(), min_age = 30)$out

# Official statistics; the third column is rescaled by 1/1000
# (presumably it is reported per 1000 individuals -- confirm against the data source)
mortality <- read.csv(file = "data/kuol_007_201700.csv", header = TRUE, sep = ",")
mortality[,3] <- mortality[,3]/1000

# Keep exactly the age groups that are present in the simulated table, in the
# same (increasing) order. The two tables are stacked below, so they must
# cover the same ages; the lower bound has to agree with 'min_age' above.
mortality <- mortality[mortality$Age %in% mortality_sim$age_group,]
mortality <- mortality[order(mortality$Age),]
stopifnot(nrow(mortality) == nrow(mortality_sim))

# Long format: first all simulated values, then all reported values
out_df <- data.frame(
    age = rep(mortality_sim$age_group, 2),
    mortality = c(mortality_sim$deaths/mortality_sim$size, mortality[,3]),
    group = gl(2, nrow(mortality_sim), labels = c("Simulated", "Reported")))

# Plot comparing mortality in the synthetic population to the official statistics
g <- ggplot(out_df, aes(x = age, y = mortality, linetype = group)) + 
    geom_line(lwd = 0.5) +
    theme_bw(base_size = 16) +
    theme(
        legend.position = c(0.190, 0.85),
        legend.key.width = unit(1.75, "cm"),
        legend.key.size = unit(1.5, "line"),
        legend.text = element_text(size = 14),
        plot.margin = unit(c(5.5, 12, 5.5, 5.5), "points"),
        legend.background = element_rect(fill = "white", size = 2, linetype = "solid")
    ) +
    scale_x_continuous(limits = c(30, 100), breaks = 3:10 * 10, expand = c(0,0)) +
    scale_y_continuous(limits = c(0, 0.4), expand = c(0.025, 0)) +
    scale_linetype_manual(name = NULL, values = c("longdash", "solid")) +
    labs(x = "Age", y = "Mortality", title = NULL)
g

######################
# Selective sampling #
######################

# A second simulator, built with the same initializer, events, seed and size
# as the main one, acts as the study population from which the target sample
# is drawn.
sim_sample <- Simulator$new(
    n = 3.6e6,
    seeds = 242,
    initializer = initializer,
    man_events = all_events,
    keep_init = TRUE
)

# Use the calibrated parameter values here as well.
# For "Stroke Mortality" only the first parameter ('initial') is set.
sim_sample$reconfigure(
    list(
        "Other Mortality" = list(shape = 12.03600, scale = 171.32261),
        "Stroke Mortality" = list(initial = 12.52546)
    )
)

## Simulates non-response with a logistic model.
## dt : a data.table with the columns sex, age, bmi, waist_circumference and smoking
## Returns one 0/1 draw per row of 'dt'; the probability of 1 is the inverse
## logit of the linear predictor built from the coefficients below.
missingness <- function(dt) {
    # Coefficients in the order: const, sex, age, bmi, waist_circumference, smoking
    coefs <- c(-9.484330, -0.311690, 0.113614, -0.090696, 0.036383, 0.689770)
    design <- as.matrix(dt[ ,.(const = 1, sex, age, bmi, waist_circumference, smoking)])
    linear_predictor <- as.numeric(design %*% coefs)
    rbern(expit(linear_predictor))
}

# Snapshot of the study population at baseline. The rows are put into 'id'
# order so that they line up with the id-ordered status taken after the run.
pop <- data.table::copy(sim_sample$get_status())
data.table::setorder(pop, id)

# Draw m distinct individuals at random from the whole population.
# 'baseline_sample' is a logical mask with one element per row of 'pop'.
m <- min(1e4, nrow(pop))
baseline_sample <- rep(FALSE, nrow(pop))
baseline_sample[sample.int(nrow(pop), m, replace = FALSE)] <- TRUE

# Non-response and baseline stroke status are stored as full-length vectors
# (NA outside the sample), so they can be combined with 'baseline_sample'
# element by element without relying on vector recycling. New vectors are
# created here, so the later by-reference update of pop$stroke cannot
# alter the baseline values.
baseline_nonresp <- rep(NA, nrow(pop))
baseline_nonresp[baseline_sample] <- missingness(pop[baseline_sample,])
baseline_stroke <- rep(NA, nrow(pop))
baseline_stroke[baseline_sample] <- pop[baseline_sample,stroke]

## Uncomment to enable a parallel cluster for remaining computations
# if (require(doParallel)) {
#     nc <- parallel::detectCores()
#     cl <- parallel::makeCluster(nc)
#     doParallel::registerDoParallel(cl)
#     ## User made functions/objects have to be manually exported, list them in 'export'
#     sim_sample$start_cluster(cl, nc, export = c("logit", "expit", "rbern", "rbern2"))
# }

sim_sample$run(10 * 365)
sim_sample$stop_cluster()

# Update stroke status after 10 years
pop_10y <- sim_sample$get_status()[order(id),]
pop[baseline_sample, stroke := pop_10y[baseline_sample,stroke]]

# Comparing two subsets: one with dropout (miss), one without (full)
# (and one with the entire population, if included, otherwise same as full)
# Outside the sample 'baseline_sample' is FALSE, so the NA values there
# yield FALSE in both masks.
included_miss <- baseline_sample & baseline_nonresp == 0 & baseline_stroke == 0
included_full <- baseline_sample & baseline_stroke == 0

# Separate models for men and women
# Logistic regressions of the 10-year stroke status stored in 'pop' on the
# baseline risk factors. The '_miss' models use the sampled individuals that
# responded and had no stroke at baseline ('included_miss'); the '_full'
# models use all sampled individuals without a stroke at baseline
# ('included_full'). Suffix '_m' selects sex == 0 and '_f' selects sex == 1.
model_miss_m <- glm(stroke ~ age + smoking + blood_pressure + hdl_cholesterol + 
                        diabetes + parent_stroke, 
                    data = pop[included_miss & sex == 0, ], family = binomial())
model_miss_f <- glm(stroke ~ age + smoking + blood_pressure + hdl_cholesterol + 
                        diabetes + parent_stroke, 
                    data = pop[included_miss & sex == 1, ], family = binomial())

model_full_m <- glm(stroke ~ age + smoking + blood_pressure + hdl_cholesterol + 
                        diabetes + parent_stroke, 
                    data = pop[included_full & sex == 0, ], family = binomial())
model_full_f <- glm(stroke ~ age + smoking + blood_pressure + hdl_cholesterol + 
                        diabetes + parent_stroke, 
                    data = pop[included_full & sex == 1, ], family = binomial())

# Optional models for the entire population; 'included_pop' is not defined
# in this script and has to be created before these lines are uncommented.
# model_pop_m <- glm(stroke ~ age + smoking + blood_pressure + hdl_cholesterol + 
#                       diabetes + parent_stroke, 
#                    data = pop[included_pop & sex == 0, ], family = binomial())
# model_pop_f <- glm(stroke ~ age + smoking + blood_pressure + hdl_cholesterol + 
#                       diabetes + parent_stroke, 
#                    data = pop[included_pop & sex == 1, ], family = binomial())

# Model summaries
summary(model_miss_m)
summary(model_miss_f)
summary(model_full_m)
summary(model_full_f)
# summary(model_pop_f)
# summary(model_pop_m)

# Exponentiated coefficients (odds ratios), rounded to two decimals
round(exp(coef(model_miss_m)), 2)
round(exp(coef(model_miss_f)), 2)
round(exp(coef(model_full_m)), 2)
round(exp(coef(model_full_f)), 2)
# round(exp(coef(model_pop_m)), 2)
# round(exp(coef(model_pop_f)), 2)

# Exponentiated confidence limits from confint(), rounded to two decimals
round(exp(confint(model_miss_m)), 2)
round(exp(confint(model_miss_f)), 2)
round(exp(confint(model_full_m)), 2)
round(exp(confint(model_full_f)), 2)
# round(exp(confint(model_pop_m)), 2)
# round(exp(confint(model_pop_f)), 2)

#################
# Interventions #
#################

# Set the interventions for each scenario
# the events 'industry_salt' and 'doctor_advice' are defined in events.R
interventions <- list(
    list(), # No intervention, baseline
    list(industry_salt),
    list(doctor_advice)
)

# Define the relevant parameters for each scenario.
# Both intervention events ADD 'effect' to blood pressure, so a reduction
# has to be given as a negative number.
intervention_pars <- list(
    list(),
    list(effect = -0.97),
    list(
        adherence = 0.5,
        effect = -2.0
    )
)

# Apply each scenario and record the results
# -- N : simulation time for each scenario
# -- M : number of replications for each scenario
# -- I : number of scenarios
# new_stroke[m, i] holds the increase in the number of individuals with
# stroke == 1 during replication m of scenario i.
N <- 10 * 365
M <- 100
I <- length(interventions)
new_stroke <- matrix(0, M, I)
for (m in seq_len(M)) {
    for (i in seq_len(I)) {
        sim$reset(seeds = 10 * m + i)
        # The baseline count is taken after the reset, i.e. from the state
        # that this run actually starts from, not from an earlier run
        init_stroke <- sum(sim$get_status()$stroke)
        if (length(interventions[[i]]) > 0) {
            interventions[[i]][[1]]$set_parameters(intervention_pars[[i]])
            sim$intervene(man_events = interventions[[i]])
        }
        sim$run(N, show_progress = FALSE)
        new_stroke[m,i] <- sum(sim$get_status()$stroke) - init_stroke
    }
}

Events

The file events.R includes the event definitions

#####################
# Event definitions #
#####################

# Aging: each evaluation of the mechanism adds 'day' (1/365 of a year) to the
# age of every individual that is still alive.
aging <- ManipulationEvent$new(
    name = "Aging",
    description = "Normal aging of individuals",
    parameters = list(day = 1.0 / 365.0),
    mechanism = expression({
        # Individuals with alive != 1 keep the age they had
        status[alive == 1L, age := age + day]
    })
)

# Mortality among living individuals without a stroke. The probability of
# staying alive at one evaluation is 1 - F(age), where F is the Weibull
# distribution function with the given shape and scale.
other_mortality <- ManipulationEvent$new(
    name = "Other Mortality",
    description = "Risk of death when no stroke has occurred",
    parameters = list(shape = 13.16346, scale = 165.34514),
    mechanism = expression({
        # u = 1 : the individual stays alive, u = 0 : the individual dies
        u <- status[
            alive == 1L & stroke == 0L,
            rbern(1 - pweibull(age, shape = shape, scale = scale))
        ]
        # Deaths from this event are recorded with cause_of_death = 1
        status[
            alive == 1L & stroke == 0L,
            `:=` (alive = u, cause_of_death = 1L - u)
        ]
    })
)

# Stroke incidence among living individuals that have not had a stroke.
# A sex-specific logistic model gives a 10-year risk, which is converted to
# a one-day risk; the stroke indicator and the age at stroke are then updated.
# Parameters:
#   beta_male, beta_female : coefficient vectors in the column order of the
#                            design matrices built below
get_stroke <- ManipulationEvent$new(
    name = "Stroke",
    description = "Risk of suffering a stroke",
    parameters = list(
                        # const  age    smoking  bp        hdl     diab.   parents' stroke
        beta_male   = c(-11.699, 0.1153, 0.4981, 0.0149,  -0.4406, 0.879,  0.2933),
        beta_female = c(-7.966,  0.0633, 0.4163, 0.00893, -0.7636, 1.2383, 0.5470)
    ),
    mechanism = expression({
        # Row indices of the individuals at risk, separately for sex == 0 and sex == 1
        subset_male   <- status[alive == 1L & stroke == 0L & sex == 0L, which = TRUE]
        subset_female <- status[alive == 1L & stroke == 0L & sex == 1L, which = TRUE]
        # Design matrices with a constant column followed by the risk factors
        X_male   <- as.matrix(status[subset_male,   
            .(const = 1, age, smoking, blood_pressure, hdl_cholesterol, diabetes, parent_stroke)])
        X_female <- as.matrix(status[subset_female, 
            .(const = 1, age, smoking, blood_pressure, hdl_cholesterol, diabetes, parent_stroke)])

        # Compute 1 day risk from 10 year risk by assuming a Poisson process.
        # Note: despite their names, prob_*_10year hold 1 - (10 year risk),
        # i.e. the probability of NOT having a stroke within 10 years.
        prob_male_10year   <- 1 - pmin(expit(as.vector(X_male %*% beta_male)), 1)
        prob_female_10year <- 1 - pmin(expit(as.vector(X_female %*% beta_female)), 1)
        # Daily rate over 3650 days; pmax() keeps the argument of log() positive
        lambda_male   <- -log(pmax(prob_male_10year, 1e-12)) / 3650.0
        lambda_female <- -log(pmax(prob_female_10year, 1e-12)) / 3650.0
        # One 0/1 draw per individual at risk, with the one-day probability
        new_stroke_male   <- rbern(1 - exp(-lambda_male))
        new_stroke_female <- rbern(1 - exp(-lambda_female))

        # Assign dates for stroke occurrence, update stroke status
        # age_at_stroke becomes the current age for a new stroke and 0 otherwise
        status[subset_male,   `:=` (age_at_stroke = age * new_stroke_male,   stroke = new_stroke_male)]
        status[subset_female, `:=` (age_at_stroke = age * new_stroke_female, stroke = new_stroke_female)]
    })
)

# Diabetes incidence among living individuals without diabetes.
# A logistic model with categorized age, BMI and waist circumference and a
# high glucose indicator gives a 10-year risk, which is converted to a
# one-day risk before the diabetes indicator is updated.
# Parameters:
#   beta : coefficient vector in the column order of the design matrix below
get_diabetes <- ManipulationEvent$new(
    name = "Diabetes",
    description = "Risk of getting diabetes",
    parameters = list(
                # const   age(45-54) age(55-64) bmi(25-30) bmi(>30) w. circ. (1) w.circ. (2) h. glucose
        beta = c(-5.514,  0.628,     0.892,     0.165,     1.096,   0.857,       1.350,      2.139)
    ),
    mechanism = expression({
        # Everyone alive and without diabetes is at risk. Individuals younger
        # than 45 form the reference age category (both age indicators are 0),
        # so they must not be excluded from the subset.
        subset <- status[alive == 1L & diabetes == 0L, which = TRUE]
        # NOTE(review): the age cut point used below is 54, whereas the labels
        # above read 45-54 and 55-64; the initial population generator uses
        # the same cut point -- confirm which one is intended.
        X <- as.matrix(status[subset, .(
            const = 1,
            age1 = (age >= 45 & age < 54), 
            age2 = (age >= 54), 
            bmi1 = (bmi > 25 & bmi <= 30), 
            bmi2 = (bmi > 30), 
            # Waist circumference categories use sex-specific thresholds
            waist_circumference1 = (sex == 0L) * (waist_circumference >= 94 & waist_circumference < 102) +
                                   (sex == 1L) * (waist_circumference >= 80 & waist_circumference < 88),
            waist_circumference2 = (sex == 0L) * (waist_circumference >= 102) + 
                                   (sex == 1L) * (waist_circumference >= 88),
            high_glucose)
        ])

        # Compute 1 day risk from 10 year risk by assuming a Poisson process.
        # prob_10year holds 1 - (10 year risk); pmax() keeps log() finite.
        prob_10year <- 1 - pmin(expit(as.vector(X %*% beta)), 1)
        lambda <- -log(pmax(prob_10year, 1e-12)) / 3650.0
        new_diabetes <- rbern(1 - exp(-lambda))

        # Update diabetes status
        status[subset, diabetes := new_diabetes]
    })
)

# Mortality among living individuals that have had a stroke.
# Parameters (both on the logit scale of the one-evaluation survival probability):
#   initial  : used while less than 28 days have passed since the stroke
#   longterm : three coefficients of the survival model used after that;
#              the logit is longterm[1] + longterm[2] * exp(t * longterm[3]),
#              where t = (days since the stroke) - 27
stroke_mortality <- ManipulationEvent$new(
    name = "Stroke Mortality",
    description = "Risk of death after suffering a stroke",
    parameters = list(
        initial = 5.41859, 
        longterm = c(8.17473693, -2.47553590, -0.01651868)
    ),
    mechanism = expression({
        # u = 1 : the individual stays alive, u = 0 : the individual dies.
        # rbern() expects a probability, so 'initial' is mapped through
        # expit() in the same way as the long-term linear predictor.
        u <- status[alive == 1L & stroke == 1L, rbern(fifelse(
            age - age_at_stroke < (28.0 / 365.0), 
            expit(initial), 
            expit(longterm[1] + longterm[2] * exp(((age - age_at_stroke) * 365 - 27) * longterm[3]))))]
        # Deaths from this event are recorded with cause_of_death = 2
        status[alive == 1L & stroke == 1L, `:=` (alive = u, cause_of_death = 2L * (1L - u))]
    })
)

# Intervention 1: 'effect' is added to the blood pressure of every living
# individual (a negative value lowers blood pressure).
industry_salt <- ManipulationEvent$new(
    name = "Intervention 1",
    description = "Industry reduces salt content in all products",
    parameters = list(effect = -0.97),
    mechanism = expression({
        status[
            alive == 1L,
            blood_pressure := blood_pressure + effect
        ]
    })
)

# Intervention 2: targets living individuals whose blood pressure is at
# least 140. Each of them follows the advice with probability 'adherence',
# in which case 'effect' is added to the blood pressure.
doctor_advice <- ManipulationEvent$new(
    name = "Intervention 2",
    description = "Advice to stop adding salt in cooking and at table",
    parameters = list(adherence = 0.5, effect = -2),
    mechanism = expression({
        # rbern2(.N, adherence) draws one 0/1 adherence indicator per targeted row
        status[
            alive == 1L & blood_pressure >= 140,
            blood_pressure := blood_pressure + effect * rbern2(.N, adherence)
        ]
    })
)

# Grouping of the events. Lists that carry the attribute 'ordered' = TRUE
# are marked as ordered; 'other_events' carries no such attribute.
stroke_events <- structure(list(stroke_mortality, get_stroke), ordered = TRUE)
other_events <- list(stroke_events, other_mortality, get_diabetes)
all_events <- structure(list(aging, other_events), ordered = TRUE)

Initial population

The file data.R contains the following function for generating the initial population (of size n)

## Generates the initial population.
## n : number of individuals to generate
## Reads data/form04_3.txt (fixed-width MONICA data) and
## data/vaesto_3112_2017.csv (age structure) relative to the working directory.
## Returns a data.table with one row per individual and the columns listed
## in the final data.table() call. Both random number generators are seeded
## with fixed values, so repeated calls with the same 'n' use the same seeds.
initializer <- function(n) {
    # Seed for initial data generation
    set.seed(100)
    dqrng::dqset.seed(100)
    # Field widths and column names of the fixed-width MONICA file
    widths <- c(2,1,2,2,6,1,3,8,8,1,1,1,1,2,1,3,1,1,4,1,3,2,1,3,1,3,2,1,1,1,1,1,1,1,1,1,2,1,
                1,3,3,2,3,3,2,1,2,2,4,2,3,3,8,3,3,8,3,4,2,3,4,4,4,2,1,2,1,1,1,1,1,4,4,4,4)
    colnames <- c("form","versn","centre","runit","serial","numsur","samunit","dexam","mbirth",
                  "agegr","sex","marit","edlevel","school","cigs","numcigs","daycigs","evercig",
                  "stop","iflyear","maxcigs","cigage","cigarsm","cigar","pipesm","pipe","othersm",
                  "hibp","drugs","bprecd","hich","chdt","chrx","chrecd","asp","menop","agem",
                  "horm","pill","syst1","diast1","rz1","syst2","diast2","rz2","cuff","arm",
                  "bpcoder","timebp","rtemp","chol","choldl","dchol","hdl","hdldl","dhdl","scn",
                  "cotin","carbmon","height","weight","waist","hip","whcoder","oversion","eage",
                  "eageg","cohort1","cohort2","edtert1","edtert2","systm","diastm","chola","hdla")
    monica <- read.fwf("data/form04_3.txt", widths = widths, col.names = colnames)

    # Sex is drawn with equal probabilities; ages 30..99 are then drawn from
    # the age structure, using column 2 for sex == 0 and column 3 for sex == 1
    # after each column has been normalized to sum to one.
    # (assumes the file has one row per age 30..99 -- TODO confirm)
    age_structure <- read.csv2(file = "data/vaesto_3112_2017.csv", header = TRUE)
    sex <- rbin(n)
    mf <- age_structure[,2:3] 
    age_structure[,2:3] <- mf / colSums(mf)[col(mf)]
    age <- (sex == 0L) * sample(30:99, n, replace = TRUE, prob = age_structure[,2]) +
           (sex == 1L) * sample(30:99, n, replace = TRUE, prob = age_structure[,3])
    # Binary indicators generated from sex- and age-group-specific proportions
    # (first vector for sex == 0, second for sex == 1)
    smoking <- sex_age_init(n, sex, age, 
                            c(0.319, 0.298, 0.278, 0.219, 0.109, 0.016),
                            c(0.229, 0.193, 0.184, 0.147, 0.054, 0.016), 30, 80, 10)

    parent_stroke <- rbin(n, prob = 0.127)
    high_glucose <- sex_age_init(n, sex, age, 
                                 c(0.013, 0.043, 0.080, 0.114, 0.128, 0.110), 
                                 c(0.019, 0.012, 0.030, 0.055, 0.121, 0.093), 30, 80, 10)
    # Age groups for baseline stroke start at 50; younger individuals get 0
    stroke <- sex_age_init(n, sex, age, 
                           c(0.027, 0.063, 0.091, 0.163), 
                           c(0.013, 0.029, 0.061, 0.137), 50, 80, 10)
    # Age at stroke: current age minus a Weibull distributed time (divided by
    # 365), but never below 30; 0 for individuals without a stroke
    age_at_stroke <- rep(0.0, n)
    age_at_stroke[stroke == 1L] <- pmax(age[stroke == 1L] - 
        (rweibull(sum(stroke), shape = 0.509246, scale = 1634.480968)) / 365.0, 30)

    # decoding missing values
    monica$height[monica$height == 999] <- NA
    monica$weight[monica$weight == 9999] <- NA
    monica$waist[monica$waist == 9999] <- NA
    monica$systm[monica$systm == 9999] <- NA
    monica$chol[monica$chol >= 888] <- NA
    monica$hdl[monica$hdl >= 777] <- NA

    # unit transformations
    monica$weight <- monica$weight / 10 # transform to kg
    monica$waist <- monica$waist / 10 # transform to cm
    monica$systm <- monica$systm / 10 # transform to mmHg
    monica$chol <- monica$chol / 10 # transform to mmol/l
    monica$hdl <- monica$hdl / 100 # transform to mmol/l
    monica$bmi <- monica$weight / (monica$height / 100)^2
    monica$sex <- monica$sex - 1 # female = 1
    #monica$smoking <- 0 + (monica$cigs %in% c(1,3))
    # Smoking indicator: code 1 or 3 in any of the cigarette, cigar or pipe variables
    monica$smoking <- 0 + (monica$cigs %in% c(1, 3) | 
                           monica$cigarsm %in% c(1, 3) | 
                           monica$pipesm %in% c(1, 3))

    # Continuous variables to be generated, in the column order of 'X' below
    selcol <- c("bmi", "waist", "chol", "hdl", "systm")
    m <- length(selcol)
    monica2 <- monica[,c("eage", "sex", "smoking", selcol)]
    # Drop the row with the smallest waist and the row with the largest cholesterol
    out <- c(which.min(monica2$waist), which.max(monica2$chol))
    monica2 <- monica2[-out,]
    monica2$agegr <- cut(monica2$eage, breaks = c(0, 30, 40, 50, 60, 70), right = FALSE)

    # Generate initial values in each subgroup of sex, smoking and age
    # For every subgroup each variable is Box-Cox transformed, and the mean
    # vector and covariance matrix of the transformed values are stored.
    norm_grp <- expand.grid(sex = 0:1, smoking = 0:1, agegr = 1:5)
    k <- nrow(norm_grp)
    norm_dat <- vector(mode = "list", length = k)
    norm_par <- vector(mode = "list", length = k)

    for (i in 1:k) {
        ind <- (monica2$sex == norm_grp[i,1] & 
                monica2$smoking == norm_grp[i,2] & 
                as.numeric(monica2$agegr) == norm_grp[i,3])
        ind[is.na(ind)] <- FALSE
        norm_par[[i]]$trans <- list()
        norm_dat[[i]] <- matrix(0, sum(ind), m)
        for (j in 1:m) {
            bc <- bestNormalize::boxcox(monica2[ind,selcol[j]])
            norm_par[[i]]$trans[[j]] <- bc
            norm_dat[[i]][,j] <- bc$x.t
        }
        norm_par[[i]]$mu <- colMeans(norm_dat[[i]], na.rm = TRUE)
        norm_par[[i]]$Sigma <- cov(norm_dat[[i]], use = "complete.obs")
    }

    X <- matrix(0, n, m)

    # Simulated ages are at least 30, so the first age group (below 30) is
    # empty; the loop therefore starts from row 5 of 'norm_grp', the first
    # row of age group 2. Values are drawn from a multivariate normal
    # distribution and mapped back with the inverse transformations.
    # NOTE(review): 'temp' is indexed as a matrix, which presumably requires
    # at least two individuals in every subgroup -- confirm for small 'n'.
    age_mon <- as.numeric(cut(age, breaks = c(0, 30, 40, 50, 60, 100), right = FALSE))
    for (i in 5:k) {
        ind <- (sex == norm_grp[i,1] &
                smoking == norm_grp[i,2] & 
                age_mon == norm_grp[i,3])
        temp <- MASS::mvrnorm(sum(ind), mu = norm_par[[i]]$mu, Sigma = norm_par[[i]]$Sigma)
        for (j in 1:m) {
            temp[,j] <- predict(norm_par[[i]]$trans[[j]], temp[,j], inverse = TRUE)
        }
        X[ind,] <- temp
    }

    # Limits data to values that have been actually observed
    # One (lower, upper) pair per column of 'X', in the order of 'selcol'
    X <- impose_limits(X, sex, smoking, age_mon, 
                       list(c(16, 72.5), c(50, 152), c(2.5, 11.5), c(0.37, 3.62), c(78, 252)))

    bmi <- X[,1]
    waist_circumference <- X[,2]
    cholesterol <- X[,3]
    hdl_cholesterol <- X[,4]
    blood_pressure <- X[,5]

    # Generating diabetes status based on FINDRISC risk score
    # The columns of 'D' follow the order of the coefficients in 'diabetes_par'.
    # NOTE(review): the age cut point is 54 here, as in the diabetes event
    # definition -- confirm against the intended age categories.
    diabetes_par <- c(-5.514, 0.628, 0.892, 0.165, 1.096, 0.857, 1.350, 2.139)
    D <- matrix(0, n, 8)
    D[,1] <- 1
    D[,2] <- (age >= 45 & age < 54)
    D[,3] <- (age >= 54)
    D[,4] <- (bmi > 25 & bmi <= 30)
    D[,5] <- (bmi > 30)
    D[,6] <- (sex == 0L) * (waist_circumference >= 94 & waist_circumference < 102) +
             (sex == 1L) * (waist_circumference >= 80 & waist_circumference < 88)
    D[,7] <- (sex == 0L) * (waist_circumference >= 102) + 
             (sex == 1L) * (waist_circumference >= 88)
    D[,8] <- high_glucose
    diabetes_risk <- expit(as.numeric(D %*% diabetes_par))
    # Rescale the risks so that their mean is 0.15 for sex == 0 and 0.10 for sex == 1
    diabetes_risk[sex == 0L] <- diabetes_risk[sex == 0L] / mean(diabetes_risk[sex == 0L]) * 0.15
    diabetes_risk[sex == 1L] <- diabetes_risk[sex == 1L] / mean(diabetes_risk[sex == 1L]) * 0.10
    diabetes <- rbern(diabetes_risk)

    # Everyone starts alive, with cause_of_death = 0 and ids 1..n
    return(data.table(
        alive = rep(1L, n),
        age = as.double(age),
        age_at_start = as.double(age),
        age_at_stroke = age_at_stroke,
        sex = sex,
        smoking = smoking,
        diabetes = diabetes,
        parent_stroke = parent_stroke,
        high_glucose = high_glucose,
        stroke = stroke,
        bmi = bmi,
        waist_circumference = waist_circumference,
        cholesterol = cholesterol,
        hdl_cholesterol = hdl_cholesterol,
        blood_pressure = blood_pressure,
        cause_of_death = rep(0L, n),
        id = 1:n
    ))

}

Utility functions

The file functions.R provides a set of useful functions used in defining events and the initial data

## Draws one Bernoulli variable per element of 'prob'.
## prob : vector of success probabilities
## Returns 1 where a uniform draw falls below the corresponding probability
## and 0 otherwise.
rbern <- function(prob) {
    draws <- dqrng::dqrunif(length(prob))
    1L * (draws < prob)
}

## Draws 'n' Bernoulli variables that share the same success probability.
## n    : number of draws
## prob : probability of success (1)
rbern2 <- function(n, prob) {
    draws <- dqrng::dqrunif(n)
    1L * (draws < prob)
}

## Samples dates from the days between 'start' and 'end' (both inclusive).
## start, end : anything accepted by as.Date()
## ...        : passed on to sample(), e.g. 'size' and 'replace'
rdate <- function(start, end, ...) {
    first_day <- as.Date(start)
    last_day <- as.Date(end)
    candidates <- seq(first_day, last_day, by = "day")
    sample(candidates, ...)
}

## Draws 'size' binary values (0L or 1L) with replacement.
## size : number of values to draw
## prob : probability of drawing 1 (0 is drawn with probability 1 - prob)
rbin <- function(size, prob = 0.5) {
    outcomes <- c(0L, 1L)
    weights <- c(1 - prob, prob)
    sample(outcomes, size = size, replace = TRUE, prob = weights)
}

## Logit (log-odds) of 'x', computed as log(x) - log(1 - x); vectorized over 'x'
logit <- function(x) {
    log(x) - log(1 - x)
}

## Inverse of the logit, computed as (1 + exp(-x))^(-1); vectorized over 'x'
expit <- function(x) {
    denominator <- 1 + exp(-x)
    denominator^(-1)
}

## Generates binary starting values from proportions given by sex and age group
## n        : number of values to generate
## sex      : vector of sex indicators (0 uses 'val_m', 1 uses 'val_f')
## age      : vector of ages (in years)
## val_m    : proportions for sex == 0, one per age group
## val_f    : proportions for sex == 1, one per age group
## age_min  : lower limit of the youngest age group; individuals below it get
##            probability zero
## age_max  : lower limit of the oldest age group; all ages above it belong
##            to that same group
## age_step : width of each age group
sex_age_init <- function(n, sex, age, val_m, val_f, age_min, age_max, age_step) {
    n_groups <- (age_max - age_min) %/% age_step + 1
    # Index of the age group of each individual, capped at the oldest group
    idx <- pmin((age - age_min) %/% age_step + 1, n_groups)
    idx[idx < 1] <- NA
    share_m <- val_m[idx]
    share_f <- val_f[idx]
    prob <- (idx > 0) * ((sex == 0) * share_m + (sex == 1) * share_f)
    # Missing probabilities (e.g. ages below 'age_min') are treated as zero
    prob[is.na(prob)] <- 0
    rbinom(n, 1, prob)
}

## Function to resample values within an allowed range
## Within every subgroup of sex, smoking and age, each value of a column of
## 'X' that is missing or outside the allowed range of that column is
## replaced by a value drawn with replacement from the in-range values of
## the same column and subgroup.
## X       : a matrix of values
## sex     : a vector of sex indicators
## smoking : a vector of smoking indicators
## age     : a vector of age groups
## limits  : a list of allowed ranges (lower, upper), one for each column of 'X'
## Returns 'X' with the offending values replaced. Stops with an error if a
## subgroup has values to replace but no in-range value to draw from.
impose_limits <- function(X, sex, smoking, age, limits) {
    age <- as.numeric(age)
    groups <- expand.grid(unique(sex), unique(smoking), unique(age))
    for (i in seq_len(nrow(groups))) {
        in_group <- which(sex == groups[i,1] & smoking == groups[i,2] & age == groups[i,3])
        if (length(in_group) == 0L) next
        for (j in seq_len(ncol(X))) {
            values <- X[in_group, j]
            is_out <- is.na(values) | values < limits[[j]][1] | values > limits[[j]][2]
            n_out <- sum(is_out)
            if (n_out == 0L) next
            pool <- values[!is_out]
            if (length(pool) == 0L) {
                stop("No values within the allowed range to resample from in column ", j, ".")
            }
            # Positions are sampled instead of the values themselves:
            # sample() on a single number x would draw from 1:x.
            X[in_group[is_out], j] <- pool[sample.int(length(pool), n_out, replace = TRUE)]
        }
    }
    return(X)
}

## Objective function to calibrate total mortality to mortality of the Finnish population
## x         : current vector of parameter values on the square root scale
##             (x[1:2] for "Other Mortality", x[3] for "Stroke Mortality")
## calib_sim : the simulator used for the calibration runs
## Reads data/kuol_007_201700.csv relative to the working directory.
## Returns the weighted sum of squared differences between the logarithms of
## reported and simulated mortality by age group, plus a penalty term based
## on 'stroke_prop' as returned by 'mortality_df'.
mortality_objective <- function(x, calib_sim) {
    mortality <- read.csv(file = "data/kuol_007_201700.csv", header = TRUE, sep = ",")
    mortality[,3] <- mortality[,3]/1000
    n <- nrow(mortality)
    # Impute 100-year-olds from 99 (linear extrapolation from the last two rows)
    mortality <- rbind(mortality, c(100, 2017, 2 * mortality[n,3] - mortality[n-1,3])) 
    # Squaring keeps the parameter values non-negative during the optimization
    p <- list("Other Mortality" = x[1:2]^2, "Stroke Mortality" = list(x[3]^2, 1))
    calib_res <- calib_sim$configure(t_sim = 364, p = p, output = mortality_df, min_age = 30)
    out <- calib_res$out
    stroke_prop <- calib_res$stroke_prop
    # Large penalty unless a positive proportion is available; isTRUE() also
    # covers the case where the proportion is NaN because nobody died
    s_obj <- 1e7
    if (isTRUE(stroke_prop > 0)) s_obj <- 2000 * (log(stroke_prop) - log(0.07534936))^2
    mortality <- mortality[mortality$Age %in% out$age_group,]
    w <- rep(1, nrow(out))
    w[out$age_group >= 60] <- 100 # Weight for mortalities >= 60, fit should be better here
    # Columns of 'out': 1 = deaths, 2 = size of the age group
    valid <- (out[ ,2] != 0 & out[ ,1] != out[,2] & out[ ,1] != 0)
    all_dead <- out[ ,1] == out[ ,2]
    all_alive <- out[ ,1] == 0
    obj <- sum(w[valid] * (log(mortality[valid,3]) - log(out[valid,1]/out[valid,2]))^2)
    # Groups in which everyone or no one died have no finite log-mortality
    # difference; the squared log of the reported mortality is used instead
    obj <- obj + sum(w[all_dead] * log(mortality[all_dead,3])^2)
    obj <- obj + sum(w[all_alive] * log(mortality[all_alive,3])^2)
    return(obj + s_obj)
}

## Tabulates deaths and group sizes by one-year age group
## dt        : the status data.table; a column 'age_group' (floor of age) is
##             added to it by reference
## min_age   : the youngest age group to keep in the output
## Returns a list with
##   out         : data.frame with columns deaths, size and age_group
##   stroke_prop : among the dead, the proportion with a stroke less than
##                 28 days (28/365 years) before their current age
mortality_df <- function(dt, min_age = 0) {
    dt[ ,age_group := floor(age)]
    group_sizes <- dt[ , .N, by = age_group]
    group_deaths <- dt[alive == 0, .N, by = age_group]

    # One row per age group from 0 up to the oldest one; row k holds group k - 1
    age_groups <- 0:max(dt$age_group)
    n_groups <- length(age_groups)
    out <- data.frame(
        deaths = rep(0, n_groups),
        size = rep(0, n_groups),
        age_group = age_groups
    )
    out[group_sizes$age_group + 1,"size"] <- group_sizes$N
    if (any(group_deaths$N > 0)) {
        out[group_deaths$age_group + 1,"deaths"] <- group_deaths$N
    }
    out <- out[out$age_group >= min_age, ]

    n_dead <- nrow(dt[alive == 0,])
    n_dead_early <- nrow(dt[alive == 0 & stroke == 1 & 
                                (age - age_at_stroke) < (28/365),])
    list(out = out, stroke_prop = n_dead_early / n_dead)
}