## ----setup, include = FALSE---------------------------------------------------
# Document-wide knitr chunk options: `collapse = TRUE` and an output prefix
# of "#>" are applied to every chunk that follows.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## ----message=FALSE------------------------------------------------------------
# Attach the packages used below. library() is used instead of require():
# require() only warns and returns FALSE when a package is missing, so the
# script would carry on and fail later with a less helpful error, whereas
# library() stops right here.
library(causalBatch)
library(ggplot2)
library(ggpubr)
library(tidyr)

# Number of samples; passed to the simulation and to the random covariate
# draws further down.
n <- 200

## ----eval=FALSE---------------------------------------------------------------
#  vignette("cb.simulations", package="causalBatch")

## -----------------------------------------------------------------------------
# Scatter plot of the first outcome dimension against a single covariate,
# with points colored by group/batch.
#
# Arguments:
#   Ys     : outcomes, indexed as `Ys[, 1]` and `Ys[, 2]` (so at least two
#            columns are required); only the first column is drawn.
#   Ts     : group/batch labels coded 0/1. Labels outside {0, 1} become NA
#            in the factor built below.
#   Xs     : covariate values, one per sample. Coerced with as.numeric() so
#            that a plain vector and a one-column matrix (with or without a
#            column name) both end up in a column called `Covariates`.
#   title, xlabel, ylabel : plot title and axis labels.
#
# Returns the ggplot object. The x axis is fixed to [-1, 1].
plot.sim <- function(Ys, Ts, Xs, title="",
                     xlabel="Covariate",
                     ylabel="Outcome (1st dimension)") {
  group.colors <- c(`0`="#bb0000", `1`="#0000bb")

  # as.numeric() mirrors plot.covars(): without it, data.frame() would name
  # the column after a one-column matrix's own column name instead of
  # `Covariates`, and the aes() mapping below would not find it.
  plot.data <- data.frame(Y1=Ys[,1], Y2=Ys[,2],
                          Group=factor(Ts, levels=c(0, 1), ordered=TRUE),
                          Covariates=as.numeric(Xs))

  # Pass the data straight to ggplot() rather than piping, so the function
  # does not rely on `%>%` having been attached by another package.
  ggplot(plot.data, aes(x=Covariates, y=Y1, color=Group)) +
    geom_point() +
    labs(title=title, x=xlabel, y=ylabel) +
    scale_x_continuous(limits = c(-1, 1)) +
    scale_color_manual(values=group.colors, name="Group/Batch") +
    theme_bw()
}

## ----fig.width=5, fig.height=3------------------------------------------------
# Draw n samples with cb.sims.sim_sigmoid (eff_sz=1, unbalancedness=1.5)
# and show the first outcome dimension against the covariate.
sim.simpl <- cb.sims.sim_sigmoid(
  n=n,
  eff_sz=1,
  unbalancedness=1.5
)

plot.sim(
  sim.simpl$Ys, sim.simpl$Ts, sim.simpl$Xs,
  title="Sigmoidal Simulation"
)

## ----eval=FALSE---------------------------------------------------------------
#  vignette("cb.detect.caus_cdcorr", package="causalBatch")

## -----------------------------------------------------------------------------
# Run cb.detect.caus_cdcorr on the simulated outcomes, labels and covariate
# (R=100), then print the p-value stored in result$Test to four decimals.
result <- cb.detect.caus_cdcorr(
  sim.simpl$Ys, sim.simpl$Ts, sim.simpl$Xs,
  R=100
)
print(sprintf("p-value: %.4f", result$Test$p.value))

## -----------------------------------------------------------------------------
# Correct the simulated data with cb.correct.matching_cComBat, supplying the
# covariate as a one-column data frame named `Covar` and matching on it.
cor.sim.simpl <- cb.correct.matching_cComBat(
  sim.simpl$Ys,
  sim.simpl$Ts,
  data.frame(Covar=sim.simpl$Xs),
  match.form="Covar"
)

## ----fig.width=5, fig.height=3------------------------------------------------
# Same scatter plot as before, now using the corrected outcomes together
# with the labels and covariate returned alongside them.
plot.sim(
  cor.sim.simpl$Ys.corrected,
  cor.sim.simpl$Ts,
  cor.sim.simpl$Xs$Covar,
  title="Sigmoidal Simulation (matching cComBat corrected)"
)

## -----------------------------------------------------------------------------
# Repeat the cb.detect.caus_cdcorr call (R=100) on the corrected outcomes
# and print the resulting p-value to four decimals.
result.cor <- cb.detect.caus_cdcorr(
  cor.sim.simpl$Ys.corrected,
  cor.sim.simpl$Ts,
  cor.sim.simpl$Xs$Covar,
  R=100
)
print(sprintf("p-value: %.4f", result.cor$Test$p.value))

## -----------------------------------------------------------------------------
# Two-covariate data frame: the simulated covariate plus n independent
# uniform(0, 1) draws.
Xs.2covar <- data.frame(
  Covar1=sim.simpl$Xs,
  Covar2=runif(n)
)

## -----------------------------------------------------------------------------
# Correct again, this time matching on both covariates via the formula
# string "Covar1 + Covar2".
cor.sim <- cb.correct.matching_cComBat(
  sim.simpl$Ys,
  sim.simpl$Ts,
  Xs.2covar,
  match.form="Covar1 + Covar2"
)

## -----------------------------------------------------------------------------
# Prepend a random binary factor (n Bernoulli(0.5) draws) to the two
# continuous covariates. Column order: Cat.Covar, then the columns of
# Xs.2covar. check.names=FALSE matches what cbind() does for data frames.
Xs.3covar <- data.frame(
  Cat.Covar=factor(rbinom(n, size=1, 0.5)),
  Xs.2covar,
  check.names=FALSE
)

## -----------------------------------------------------------------------------
# Options handed to the correction through `match.args`; the values are
# passed along unchanged.
match.args <- list(
  method="nearest",
  exact="Cat.Covar",
  replace=FALSE,
  caliper=0.1
)

# Correct using all three covariates in the matching formula together with
# the options above.
cor.sim <- cb.correct.matching_cComBat(
  sim.simpl$Ys,
  sim.simpl$Ts,
  Xs.3covar,
  match.form="Covar1 + Covar2 + Cat.Covar",
  match.args=match.args
)

## -----------------------------------------------------------------------------
# Overlaid, semi-transparent histograms of a single covariate, one per
# group/batch.
#
# Arguments:
#   Ts     : group/batch labels coded 0/1. Labels outside {0, 1} become NA
#            in the factor built below.
#   Xs     : covariate values, one per sample; coerced with as.numeric().
#   title, xlabel, ylabel : plot title and axis labels.
#   bins   : number of histogram bins. The default of 30 is the value
#            geom_histogram() falls back to when none is given; passing it
#            explicitly avoids the message ggplot2 prints in that case and
#            lets callers choose a different resolution.
#
# Returns the ggplot object. The x axis is fixed to [-1, 1] and the count
# axis to [0, 12].
# NOTE(review): bars taller than 12 fall outside the y limits; confirm the
# limit is adequate if `n` or `bins` is changed.
plot.covars <- function(Ts, Xs, title="",
                        xlabel="Covariate",
                        ylabel="Number of Samples",
                        bins=30) {
  # One mapping shared by the outline (color) and fill scales.
  group.colors <- c(`0`="#bb0000", `1`="#0000bb")

  hist.data <- data.frame(Covariates=as.numeric(Xs),
                          Group=factor(Ts, levels=c(0, 1), ordered=TRUE))

  # Pass the data straight to ggplot() rather than piping, so the function
  # does not rely on `%>%` having been attached by another package.
  ggplot(hist.data, aes(x=Covariates, color=Group, fill=Group)) +
    # position="identity" overlays the groups instead of stacking them.
    geom_histogram(position="identity", alpha=0.5, bins=bins) +
    labs(title=title, x=xlabel, y=ylabel) +
    scale_x_continuous(limits = c(-1, 1)) +
    scale_y_continuous(limits=c(0, 12)) +
    scale_color_manual(values=group.colors, name="Group/Batch") +
    scale_fill_manual(values=group.colors, name="Group/Batch") +
    theme_bw()
}

# Stack two covariate histograms: all simulated samples on top, and the
# samples returned by the single-covariate correction underneath.
ggarrange(
  plot.covars(
    sim.simpl$Ts, sim.simpl$Xs,
    title="(A) Unfiltered samples"
  ),
  plot.covars(
    cor.sim.simpl$Ts, cor.sim.simpl$Xs$Covar,
    title="(B) Matched + Trimmed samples"
  ),
  nrow=2
)

## ----fig.width=5, fig.height=4.5----------------------------------------------
# Same single-covariate correction as earlier, but with apply.oos=TRUE.
cor.sim.oos <- cb.correct.matching_cComBat(
  sim.simpl$Ys,
  sim.simpl$Ts,
  data.frame(Covar=sim.simpl$Xs),
  match.form="Covar",
  apply.oos=TRUE
)

# Top: covariate histogram of the returned samples.
# Bottom: their corrected outcomes against the covariate.
ggarrange(
  plot.covars(
    cor.sim.oos$Ts, cor.sim.oos$Xs$Covar,
    title="(A) In- and out-of-sample data"
  ),
  plot.sim(
    cor.sim.oos$Ys.corrected, cor.sim.oos$Ts, cor.sim.oos$Xs$Covar,
    title="(B) matching cComBat on in- and out-of-sample data"
  ),
  nrow=2
)


## ----fig.width=5, fig.height=3------------------------------------------------
# Indices returned by cb.align.vm_trim for the simulated labels/covariate.
oos.ids <- cb.align.vm_trim(sim.simpl$Ts, sim.simpl$Xs)

# Subset outcomes, labels and covariate to those indices. drop=FALSE keeps
# the outcome and covariate objects two-dimensional after subsetting.
Ys.oos <- sim.simpl$Ys[oos.ids, , drop=FALSE]
Ts.oos <- sim.simpl$Ts[oos.ids]
Xs.oos <- sim.simpl$Xs[oos.ids, , drop=FALSE]

# Apply the model stored in cor.sim.oos$Model to the subset, passing the
# covariate as a data frame with the same column name (`Covar`) used when
# that model was produced.
Ys.oos.cor <- cb.correct.apply_cComBat(
  Ys.oos,
  Ts.oos,
  data.frame(Covar=Xs.oos),
  cor.sim.oos$Model
)

plot.sim(
  Ys.oos.cor, Ts.oos, Xs.oos,
  title=" matching cComBat applied to OOS data"
)