We can train the model with the caret package (for further information about caret, see the original website). We use parallel computing to speed up the computation.
# parallel computing
library(doParallel)
cl <- makePSOCKcluster(5)
registerDoParallel(cl)
# stop the cluster only after the model training below has finished
stopCluster(cl)
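As printed, the snippet above creates the cluster and then stops it immediately. In practice, stopCluster() should run only after estimate_itr() returns. The following is a minimal sketch of one way to guarantee that ordering, assuming the doParallel and foreach packages are installed. The helper name with_parallel() is hypothetical and not part of evalITR or caret, and the toy foreach loop only stands in for the model training.

```r
library(doParallel)  # also attaches foreach and parallel

# Hypothetical helper: run `task` on a temporary PSOCK cluster and always
# shut the cluster down afterwards, even if `task` throws an error.
with_parallel <- function(n_workers, task) {
  cl <- makePSOCKcluster(n_workers)
  registerDoParallel(cl)
  on.exit({
    stopCluster(cl)
    registerDoSEQ()  # fall back to sequential %dopar% once the workers are gone
  }, add = TRUE)
  task()
}

# Toy workload standing in for estimate_itr(); %dopar% sends it to the workers
squares <- with_parallel(2, function() {
  foreach(i = 1:4, .combine = c) %dopar% i^2
})
print(squares)
```

In the workflow of this vignette, the task would wrap the estimate_itr() call, for example with_parallel(5, function() estimate_itr(...)), so the cluster is released only after training finishes.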
The following example shows how to estimate the ITR with a gradient boosting machine (GBM) using the caret package. Note that we have already loaded the data and specified the treatment, outcome, and covariates as shown in the Sample Splitting vignette.
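The formula itself is defined in the Sample Splitting vignette and is not repeated here. For readers starting from this page, here is a sketch of building such a formula from column names with base R's reformulate(). The column names are hypothetical placeholders, and placing the treatment on the right-hand side is our assumption about how that vignette sets up user_formula.

```r
# Hypothetical column names; replace them with the ones in your data
outcome <- "writing"
treatment <- "treatment"
covariates <- c("gender", "race", "birthyear")

# Builds: outcome ~ treatment + covariates
user_formula <- reformulate(c(treatment, covariates), response = outcome)
print(user_formula)
```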
Since we are using the caret package, we need to specify the trainControl and/or tuneGrid arguments. The trainControl argument specifies the cross-validation method, and the tuneGrid argument specifies the tuning grid. For more information about these arguments, please refer to the caret website.
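To get a sense of how much computation these two arguments imply, the sketch below counts candidate models for the gbmGrid and fitControl settings used in the example (3 folds, 3 repeats). It is plain base R arithmetic. The final count reflects our understanding that caret's gbm method can reuse a single fit for several n.trees values. This is worth checking against the caret documentation rather than taking as given.

```r
# Same tuning grid as gbmGrid in the example
gbmGrid <- expand.grid(
  interaction.depth = c(1, 5, 9),
  n.trees = (1:30) * 50,
  shrinkage = 0.1,
  n.minobsinnode = 20)

n_candidates <- nrow(gbmGrid)  # 3 depths x 30 tree counts = 90 rows
n_resamples <- 3 * 3           # number = 3 folds, repeats = 3
naive_fits <- n_candidates * n_resamples

# Candidate settings that differ in something other than n.trees
other_params <- c("interaction.depth", "shrinkage", "n.minobsinnode")
n_fits_reusing_trees <- nrow(unique(gbmGrid[, other_params])) * n_resamples

cat("candidates:", n_candidates,
    "| naive fits:", naive_fits,
    "| fits if n.trees is reused:", n_fits_reusing_trees, "\n")
```

Because every resample multiplies the work, enlarging the grid or the number of repeats is where parallel computing pays off most.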
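We estimate the ITR with only one machine learning algorithm (GBM) and evaluate it with the evaluate_itr() function. To compute PAPDp, we need to pass at least two machine learning algorithms to the algorithms argument, because PAPDp compares the ITRs derived from two algorithms. With GBM alone, evaluate_itr() reports that it cannot compute PAPDp.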
library(evalITR)
# specify the trainControl method
fitControl <- caret::trainControl(
method = "repeatedcv", # repeated k-fold CV
number = 3,            # 3 folds
repeats = 3,           # repeated 3 times
search = "grid",       # grid search over tuneGrid
allowParallel = TRUE)  # use the registered parallel backend
# specify the tuning grid
gbmGrid <- expand.grid(
interaction.depth = c(1, 5, 9),
n.trees = (1:30)*50,
shrinkage = 0.1,
n.minobsinnode = 20)
# estimate ITR
fit_caret <- estimate_itr(
treatment = "treatment",
form = user_formula,
trControl = fitControl,
data = star_data,
algorithms = c("gbm"),
budget = 0.2,
split_ratio = 0.7,
tuneGrid = gbmGrid,
verbose = FALSE)
#> Evaluate ITR under sample splitting ...
# evaluate ITR
est_caret <- evaluate_itr(fit_caret)
#> Cannot compute PAPDp
We can extract the trained model from caret and check its performance. Other functions from caret can also be applied to the trained model.
# extract the final model
caret_model <- fit_caret$estimates$models$gbm
print(caret_model$finalModel)
#> A gradient boosted model with gaussian loss function.
#> 50 iterations were performed.
#> There were 53 predictors of which 36 had non-zero influence.
# check model performance (caretTheme() comes from caret, trellis.par.set() from lattice)
library(caret)
trellis.par.set(caretTheme()) # theme
plot(caret_model)
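As an illustration of applying other caret functions, the sketch below fits a small GBM with caret::train() on simulated data and then inspects it with bestTune, results, and varImp(). The data, grid, and seed are made up for the example and are unrelated to the STAR data. It assumes the caret and gbm packages are installed. On the object returned by estimate_itr(), the same calls would be applied to caret_model.

```r
library(caret)  # also attaches lattice and ggplot2

# Simulated data: the outcome depends on x1 and x2 only (x3 is noise)
set.seed(123)
n <- 300
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
sim$y <- sim$x1 + 0.5 * sim$x2 + rnorm(n)

ctrl <- trainControl(method = "cv", number = 3)
grid <- expand.grid(interaction.depth = c(1, 2),
                    n.trees = c(50, 100),
                    shrinkage = 0.1,
                    n.minobsinnode = 10)

toy_model <- train(y ~ ., data = sim, method = "gbm",
                   trControl = ctrl, tuneGrid = grid,
                   verbose = FALSE)  # passed through to gbm()

toy_model$bestTune       # tuning values selected by cross-validation
head(toy_model$results)  # resampled performance (e.g. RMSE, Rsquared) per candidate
varImp(toy_model)        # relative influence of each predictor, scaled to 0-100
```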
The summary() function displays the following summary statistics: (1) the population average prescriptive effect (PAPE); (2) the population average prescriptive effect with a budget constraint (PAPEp); (3) the population average prescriptive effect difference with a budget constraint (PAPDp), which is computed only when at least two machine learning algorithms are specified; (4) the area under the prescriptive effect curve (AUPEC); and (5) the grouped average treatment effects (GATEs). For more information about the first four evaluation metrics, please refer to Imai and Li (2021). The details of the GATE methods for this design are given in Imai and Li (2022).
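To make the PAPDp requirement concrete, here is how we read the budget-constrained quantities in Imai and Li (2021), written in our own notation as a sketch rather than a transcription of the paper. Here, f and g are two ITRs, p is the budget (the maximal proportion of units treated), f_p(X_i) treats the fraction p of units ranked highest by f, and Y_i(1), Y_i(0) are the potential outcomes.

```latex
\tau_f(p) = \mathbb{E}\left[ Y_i\big(f_p(X_i)\big) - p\,Y_i(1) - (1-p)\,Y_i(0) \right]

\Delta_p(f, g) = \mathbb{E}\left[ Y_i\big(f_p(X_i)\big) - Y_i\big(g_p(X_i)\big) \right] = \tau_f(p) - \tau_g(p)
```

The random-assignment terms cancel in the difference, so PAPDp is defined only for a pair of ITRs. With GBM alone there is nothing to compare, which is why the PAPDp table in the output below is empty.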
# summarize estimates
summary(est_caret)
#> ── PAPE ────────────────────────────────────────────────────────────────────────
#> estimate std.deviation algorithm statistic p.value
#> 1 -0.35 1.5 gbm -0.24 0.81
#>
#> ── PAPEp ───────────────────────────────────────────────────────────────────────
#> estimate std.deviation algorithm statistic p.value
#> 1 1.6 1.3 gbm 1.2 0.21
#>
#> ── PAPDp ───────────────────────────────────────────────────────────────────────
#> data frame with 0 columns and 0 rows
#>
#> ── AUPEC ───────────────────────────────────────────────────────────────────────
#> estimate std.deviation algorithm statistic p.value
#> 1 0.22 1.1 gbm 0.19 0.85
#>
#> ── GATE ────────────────────────────────────────────────────────────────────────
#> estimate std.deviation algorithm group statistic p.value upper lower
#> 1 105 109 gbm 1 0.96 0.34 -75 285
#> 2 -60 108 gbm 2 -0.56 0.58 -238 117
#> 3 -139 107 gbm 3 -1.30 0.19 -315 37
#> 4 64 108 gbm 4 0.59 0.55 -114 243
#> 5 51 109 gbm 5 0.47 0.64 -128 230
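The printed statistics can be reproduced by hand. The sketch below copies the rounded GATE estimates and standard deviations from the output above and recomputes the z-statistic, a two-sided normal p-value, and a 90% interval. The 90% level is our inference from the printed bounds (half-width of roughly 1.645 standard deviations), not something stated in this vignette. Note also that in the printed table, the column labelled upper holds the smaller endpoint and lower holds the larger one.

```r
# GATE rows copied from the summary output (rounded values)
gate <- data.frame(
  group = 1:5,
  estimate = c(105, -60, -139, 64, 51),
  std.deviation = c(109, 108, 107, 108, 109)
)

gate$statistic <- gate$estimate / gate$std.deviation
gate$p.value <- 2 * pnorm(-abs(gate$statistic))  # two-sided normal p-value

z90 <- qnorm(0.95)  # about 1.645: critical value for a two-sided 90% interval
gate$ci_low <- gate$estimate - z90 * gate$std.deviation
gate$ci_high <- gate$estimate + z90 * gate$std.deviation

print(round(gate, 2))
```

Small discrepancies from the printed table (for example, 241.7 versus 243 for group 4) come from starting with rounded estimates.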
We plot the estimated area under the prescriptive effect curve (AUPEC) for the writing score across a range of budget constraints for the gradient boosting machine.
# plot the AUPEC
plot(est_caret)