We can estimate the ITR with various machine learning algorithms and then compare the performance of each model. The package supports all ML algorithms available in the caret package, plus two additional algorithms: causal forest and bartCause.

The package also allows estimating heterogeneous treatment effects at the individual and group levels. At the individual level, the summary statistics and the AUPEC plot show whether assigning individualized treatment rules may outperform a completely randomized experiment. At the group level, we specify the number of groups through ngates and estimate heterogeneous treatment effects across groups.

library(evalITR)

# specify the trainControl method
fitControl <- caret::trainControl(
                           method = "repeatedcv",
                           number = 3,
                           repeats = 3)
# estimate ITR
set.seed(2021)
fit_cv <- estimate_itr(
               treatment = "treatment",
               form = user_formula,
               data = star_data,
               trControl = fitControl,
               algorithms = c(
                  "causal_forest", 
                  "bartc",
                  # "rlasso", # from rlearner 
                  # "ulasso", # from rlearner 
                  "lasso", # from caret package
                  "rf"), # from caret package
               budget = 0.2,
               n_folds = 3)
#> Evaluate ITR with cross-validation ...
#> fitting treatment model via method 'bart'
#> fitting response model via method 'bart'
#> fitting treatment model via method 'bart'
#> fitting response model via method 'bart'
#> fitting treatment model via method 'bart'
#> fitting response model via method 'bart'

# evaluate ITR
est_cv <- evaluate_itr(fit_cv)

# summarize estimates
summary(est_cv)
#> ── PAPE ────────────────────────────────────────────────────────────────────────
#>   estimate std.deviation     algorithm statistic p.value
#> 1    0.954          0.82 causal_forest     1.168    0.24
#> 2   -0.028          0.44         bartc    -0.064    0.95
#> 3    0.173          1.07         lasso     0.162    0.87
#> 4    1.266          0.95            rf     1.335    0.18
#> 
#> ── PAPEp ───────────────────────────────────────────────────────────────────────
#>   estimate std.deviation     algorithm statistic p.value
#> 1     2.55          0.65 causal_forest      3.91 9.2e-05
#> 2     1.75          0.90         bartc      1.95 5.2e-02
#> 3    -0.21          0.63         lasso     -0.33 7.4e-01
#> 4     1.69          1.11            rf      1.52 1.3e-01
#> 
#> ── PAPDp ───────────────────────────────────────────────────────────────────────
#>   estimate std.deviation             algorithm statistic p.value
#> 1    0.803          0.94 causal_forest x bartc     0.853 0.39353
#> 2    2.760          0.80 causal_forest x lasso     3.458 0.00054
#> 3    0.868          0.71    causal_forest x rf     1.219 0.22292
#> 4    1.958          1.06         bartc x lasso     1.848 0.06453
#> 5    0.065          1.13            bartc x rf     0.057 0.95427
#> 6   -1.893          0.72            lasso x rf    -2.615 0.00892
#> 
#> ── AUPEC ───────────────────────────────────────────────────────────────────────
#>   estimate std.deviation     algorithm statistic p.value
#> 1     1.43           1.5 causal_forest      0.92    0.36
#> 2     0.81           1.4         bartc      0.58    0.56
#> 3     0.18           1.4         lasso      0.13    0.90
#> 4     1.37           1.6            rf      0.88    0.38
#> 
#> ── GATE ────────────────────────────────────────────────────────────────────────
#>    estimate std.deviation     algorithm group statistic p.value upper lower
#> 1    -118.1            59 causal_forest     1    -2.013   0.044  -3.1  -233
#> 2      27.0            59 causal_forest     2     0.454   0.650 143.5   -90
#> 3      60.9            59 causal_forest     3     1.034   0.301 176.4   -55
#> 4       7.6            59 causal_forest     4     0.128   0.898 123.7  -109
#> 5      40.9            99 causal_forest     5     0.411   0.681 235.8  -154
#> 6      28.7            80         bartc     1     0.357   0.721 186.3  -129
#> 7    -145.7            84         bartc     2    -1.737   0.082  18.7  -310
#> 8      51.5            99         bartc     3     0.522   0.601 245.0  -142
#> 9      40.2            59         bartc     4     0.681   0.496 155.9   -76
#> 10     43.4            87         bartc     5     0.498   0.619 214.5  -128
#> 11    -14.4            94         lasso     1    -0.154   0.878 169.2  -198
#> 12    -94.5            90         lasso     2    -1.051   0.293  81.8  -271
#> 13     87.9            99         lasso     3     0.886   0.376 282.4  -107
#> 14     12.6            59         lasso     4     0.214   0.830 127.8  -103
#> 15     26.6            59         lasso     5     0.451   0.652 142.4   -89
#> 16    -37.4            59            rf     1    -0.638   0.523  77.5  -152
#> 17     10.6            59            rf     2     0.180   0.857 126.5  -105
#> 18    -17.6            59            rf     3    -0.299   0.765  97.7  -133
#> 19     66.5            86            rf     4     0.770   0.441 235.9  -103
#> 20     -3.9            60            rf     5    -0.066   0.948 113.0  -121
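The GATE table above reports estimates for five groups. As a minimal sketch, the number of groups can be set explicitly through the ngates argument of estimate_itr(); passing ngates = 5 here (matching the five groups in the table) is an assumption about the value used, and the other arguments mirror the call shown earlier.

```r
library(evalITR)

# Sketch: set the number of groups for GATE estimation via ngates.
# ngates = 5 is assumed to match the five groups reported above.
set.seed(2021)
fit_gate <- estimate_itr(
               treatment = "treatment",
               form = user_formula,
               data = star_data,
               algorithms = c("causal_forest"),
               budget = 0.2,
               ngates = 5,
               n_folds = 3)

# evaluate and summarize as before; the GATE section of the summary
# will contain one row per group for each algorithm
est_gate <- evaluate_itr(fit_gate)
summary(est_gate)
```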

We plot the estimated Area Under the Prescriptive Effect Curve (AUPEC) for the writing score across the different ML algorithms.

# plot the AUPEC with different ML algorithms
plot(est_cv)