## ----echo = FALSE, warning=FALSE----------------------------------------------
library(YEAB)

## -----------------------------------------------------------------------------
# Simulate inter-response times (IRTs) as a pooled sample from two exponential
# distributions. The seed is fixed so the draws are reproducible.
set.seed(43)
l1 <- 1 / 10 # rate of the first exponential component
l2 <- 1 / 40 # rate of the second exponential component
p <- 0.4 # share of the sample drawn from the first component
n <- 200 # total number of simulated IRTs
delta <- 0.03 # constant added to every simulated IRT

# Draw the first component, then the second, and shift the pooled sample.
irt <- delta + c(
  rexp(round(n * p), rate = l1),
  rexp(round(n * (1 - p)), rate = l2)
)
plot(irt)

## -----------------------------------------------------------------------------
# Run biexponential() on the simulated IRTs; at top level the result is
# printed automatically.
biexponential(irt)

## -----------------------------------------------------------------------------
# Run berm() on the same IRTs, passing the shift that was added to the sample
# during simulation (`delta`) rather than repeating the literal value.
berm(irt, delta)

## -----------------------------------------------------------------------------
library(ks)
library(splines)

# Estimate a smoothed survival curve, p(X > x), from a sample using a kernel
# density estimate (KDE).
#
# Arguments:
#   log_data       Numeric vector of observations on their original scale.
#                  Despite the name, the log is taken inside this function
#                  when `log_transform` is TRUE.
#   n_points       Number of equally spaced grid points (between the minimum
#                  and maximum of the possibly transformed data) at which the
#                  curve is evaluated.
#   log_transform  If TRUE (default), the KDE is fitted to
#                  log(log_data + 1e-5) and the grid is mapped back to the
#                  original scale before returning. If FALSE, the data and the
#                  returned grid are left untouched.
#
# Returns a data.frame with columns `x` (grid on the original data scale) and
# `survival` (spline-smoothed survival estimate at each grid point).
compute_survival_kde <- function(
    log_data,
    n_points = 100,
    log_transform = TRUE) {
  offset <- 1e-5 # small constant that avoids log(0)

  values <- if (log_transform) log(log_data + offset) else log_data

  # Density estimate evaluated on an equally spaced grid over the data range.
  kde_fit <- kde(values)
  grid <- seq(min(values), max(values), length.out = n_points)
  density_at_grid <- predict(kde_fit, x = grid)

  # The normalised running sum of the density approximates the CDF on the
  # grid; the survival function is its complement.
  cdf <- cumsum(density_at_grid) / sum(density_at_grid)
  survival <- 1 - cdf

  # Smooth the survival curve with a spline (spar controls the smoothness).
  smooth_fit <- smooth.spline(grid, survival, spar = 0.7)

  # Undo the log transform only when it was actually applied; otherwise the
  # grid is already on the data scale and must be returned as is.
  x_out <- if (log_transform) exp(grid) - offset else grid

  data.frame(x = x_out, survival = predict(smooth_fit, grid)$y)
}

# Simulate inter-response times (IRTs) from a two-component exponential
# mixture and return a KDE-based, spline-smoothed survival curve for them.
#
# Arguments:
#   params      Named numeric vector with elements "w" (share of the sample
#               drawn from the first component), "l0" and "l1" (used as
#               1 / rate of the first and second exponential component), and,
#               for the "berm" model, "d" (shift added to every IRT).
#   model_type  Either "biexponential" (no shift) or "berm" (shifted mixture).
#   delta       Optional shift for the "berm" model. When NULL (default) the
#               shift is taken from params["d"]; it is ignored for the
#               "biexponential" model.
#   n           Number of IRTs to simulate; also the number of grid points at
#               which the survival curve is evaluated.
#
# Returns a data.frame with columns `x` (equally spaced grid over the range of
# the simulated IRTs) and `survival` (smoothed survival estimate).
#
# Stops with an error if `model_type` is not one of the two supported values.
simulate_survival <- function(params, model_type, delta = NULL, n = 100) {
  if (!model_type %in% c("biexponential", "berm")) {
    stop("Invalid model type. Choose 'biexponential' or 'berm'.")
  }

  # `[[` extracts bare scalars, so no element names leak into the results.
  q <- params[["w"]]
  scale_0 <- params[["l0"]]
  scale_1 <- params[["l1"]]

  # Only the BERM model shifts the IRTs. An explicitly supplied `delta` takes
  # precedence over the value stored in `params`.
  shift <- 0
  if (model_type == "berm") {
    shift <- if (is.null(delta)) params[["d"]] else delta
  }

  # Split the sample between the two components and draw them in order.
  n1 <- round(n * q)
  n2 <- n - n1
  irt_sim <- c(rexp(n1, rate = 1 / scale_0), rexp(n2, rate = 1 / scale_1)) +
    shift

  # KDE-based survival function of the simulated IRTs on an equally spaced
  # grid: the normalised running sum of the density approximates the CDF.
  kde_fit <- kde(irt_sim)
  grid <- seq(min(irt_sim), max(irt_sim), length.out = n)
  density_at_grid <- predict(kde_fit, x = grid)
  cdf <- cumsum(density_at_grid) / sum(density_at_grid)
  survival <- 1 - cdf

  # Smooth the survival curve with a spline.
  smooth_fit <- smooth.spline(grid, survival, spar = 0.7)
  data.frame(x = grid, survival = predict(smooth_fit, grid)$y)
}

# Plot the survival function of the observed data together with survival
# functions of data simulated from the biexponential and BERM models.
#
# Arguments:
#   log_data              Numeric vector of observed IRTs on their original
#                         scale (compute_survival_kde() applies the log
#                         transform itself).
#   berm_params           Named parameter vector passed to
#                         simulate_survival() for the "berm" model.
#   biexponential_params  Named parameter vector passed to
#                         simulate_survival() for the "biexponential" model.
#   n_points              Number of grid points for the observed-data curve.
#   num_sims              Number of simulated curves drawn per model; 0 draws
#                         only the observed-data curve.
#
# Called for its side effect of drawing on the active graphics device. Each
# simulated data set has as many observations as `log_data`. All
# biexponential simulations are drawn before the BERM simulations.
plot_survival_comparison <- function(
    log_data,
    berm_params,
    biexponential_params,
    n_points = 200,
    num_sims = 100) {
  n_obs <- length(log_data)

  # Semi-transparent colours for the simulated curves and their opaque
  # counterparts for the legend.
  biexp_line_col <- rgb(251, 107, 95, alpha = 60, maxColorValue = 255) # red
  berm_line_col <- rgb(59, 132, 23, alpha = 52, maxColorValue = 255) # green
  leg_red <- rgb(251, 107, 95, alpha = 255, maxColorValue = 255)
  leg_green <- rgb(59, 132, 23, alpha = 255, maxColorValue = 255)

  # Draws `num_sims` simulated survival curves for one model. seq_len() makes
  # the loop a no-op when num_sims is 0 (1:0 would iterate twice).
  draw_simulated_curves <- function(params, model_type, line_col) {
    for (i in seq_len(num_sims)) {
      survival_sim <- simulate_survival(params, model_type, n = n_obs)
      lines(survival_sim$x, survival_sim$survival, col = line_col, lwd = 0.5)
    }
  }

  # Survival function of the original data, drawn in blue on a logarithmic
  # y-axis limited to [1e-4, 1].
  survival_orig <- compute_survival_kde(log_data, n_points = n_points)
  plot(survival_orig$x,
    survival_orig$survival,
    type = "l", col = "blue", lwd = 2,
    xlab = "Inter-Response Time",
    ylab = "p(IRT > x)",
    main = "Survival Function Comparison",
    log = "y",
    ylim = c(0.0001, 1)
  )

  # Simulated curves: biexponential model in red, then BERM model in green.
  draw_simulated_curves(biexponential_params, "biexponential", biexp_line_col)
  draw_simulated_curves(berm_params, "berm", berm_line_col)

  legend("topright",
    legend = c("Original Data", "Biexponential Model", "BERM Model"),
    col = c("blue", leg_red, leg_green), lty = 1, lwd = 2
  )
}

# Example usage with generated data and fitted parameters
set.seed(43)

# Generate the "original" data: a pooled sample from two exponential
# distributions with rates l1 and l2, shifted by delta.
l1 <- 1 / 20
l2 <- 1 / 2
p <- 0.07 # share of the sample drawn from the first component
n <- 400 # total number of IRTs
delta <- 0.03
irt <- delta + c(
  rexp(round(n * p), rate = l1),
  rexp(round(n * (1 - p)), rate = l2)
)

# Model parameters, taken from the output of the model functions and stored
# as named numeric vectors so they can be looked up by name.
biexponential_params <- setNames(
  as.numeric(biexponential(irt)),
  c("w", "l0", "l1")
)
berm_params <- setNames(
  as.numeric(berm(irt, delta)),
  c("w", "l0", "l1", "d")
)

# Plot comparison of survival functions
plot_survival_comparison(
  irt,
  berm_params,
  biexponential_params,
  n_points = 300,
  num_sims = 100
)