Tree-based methods
Extensions to tree (Ensembles)
Interpretable ML
Causal Forest treatment effect
Classical regression
Lasso / Ridge
Trees / Ensembles
Regression:
The outcome is numeric (e.g., blood pressure).
Model output: prediction \(\hat{y}\).
Loss:
\[ \mathbf{MSE} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{f}(x_i))^2 \]
Classification:
The outcome is class (categorical) (e.g., disease yes/no).
Model output: class probabilities \(\hat{p}(y=k|x)\) for \(k = 1, 2, \dots, K\), or a class label via a threshold.
Model Accuracy / Loss: misclassification error rate, Gini, cross-entropy (deviance).
Regression tree example:
Classification tree example:
Root node: the top of the tree; all data before any splits
Internal nodes: points where the data is split based on a rule
Leaves (terminal nodes): final regions \(R_1, R_2, \dots\) where predictions are made
Trees are usually drawn with the root at the top and the leaves at the bottom.
Partition: Divide the predictor space (the set of possible values of \(X_1, X_2, \dots, X_p\)) into \(J\) distinct, non-overlapping regions \(R_1, R_2, \dots, R_J\).
Predict within a leaf: For every observation that falls into region \(R_m\), we make the same prediction.
Regression: mean outcome for each region \[\hat{\mu}_m=\dfrac{1}{N_m}\sum_{i\in R_m} y_i\]
Classification: proportion of each class \(k\)
\[\hat p_{mk}=\dfrac{1}{N_m}\sum_{i\in R_m}\mathbf{1}\{y_i=k\}\]
The prediction is the class with the most observations (equivalently, the highest proportion):
\[\arg\max_k \hat p_{mk}\]
We pick the tree with the lowest prediction error.
It is computationally infeasible to consider every possible partition of the feature space into \(J\) boxes.
Instead, we use a top-down, greedy algorithm to find the best split at each internal node of the tree.
For each predictor \(x_j\) and candidate split point \(s\), define the two child regions
\[ R_L=\{x_{ij}<s\,\},\qquad R_R=\{x_{ij}\ge s\,\}. \]
For regression, the split criterion is the total sum of squared errors in the two children:
\[ \mathrm{SSE}(j,s)=\sum_{i\in R_L}(y_i-\bar y_L)^2+\sum_{i\in R_R}(y_i-\bar y_R)^2, \]
where \(\bar y_L\) and \(\bar y_R\) are means in the two children.
For classification, the split criterion is the weighted child impurity \[ \mathrm{Imp}(j,s)=\frac{N_L}{N}\,\mathrm{Imp}(R_L)+\frac{N_R}{N}\,\mathrm{Imp}(R_R), \] where \(N_L\), \(N_R\), and \(N\) are the numbers of observations in the left child, right child, and parent node, respectively,
with \(\mathrm{Imp}()\) either
\[ \mathrm{Gini}(R)=1-\sum_k \hat p_k^2 \qquad\text{or}\qquad \mathrm{Entropy}(R)=-\sum_k \hat p_k\log \hat p_k. \] The misclassification error, \(1-\max_k \hat p_k\), is not sensitive enough to node purity to be preferred for growing trees.
\[ (j^*,s^*)=\arg\min_{(j,s)}\ \text{error}(j,s). \]
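To make the greedy search concrete, here is a minimal sketch (not the tree package's implementation) that scores every candidate cutpoint \(s\) of one numeric predictor with the regression SSE criterion above; the function name and toy data are illustrative only.
# Score all candidate cutpoints of one numeric predictor x for outcome y (regression SSE)
best_split_sse <- function(x, y) {
  xs     <- sort(unique(x))
  s_grid <- (head(xs, -1) + tail(xs, -1)) / 2      # midpoints between observed values
  sse <- sapply(s_grid, function(s) {
    left  <- y[x <  s]
    right <- y[x >= s]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  list(best_s = s_grid[which.min(sse)], best_sse = min(sse))
}
# Toy check: the outcome jumps at x = 0.6, so the chosen cutpoint should be near 0.6
set.seed(1)
x <- runif(50); y <- 2 * (x > 0.6) + rnorm(50, sd = 0.3)
best_split_sse(x, y)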
Salary is color-coded from low (light blue) to high (dark blue).
A five-region example is shown in the next slide.
Deep trees fit the training data very well but often overfit (poor test performance). Why?
A smaller tree (fewer regions \(R_1,\dots,R_J\)) usually has lower variance and is easier to interpret (at a small bias cost).
Grow the tree only while each split reduces RSS/SSE (or impurity) by more than a threshold.
Issue: short-sighted — an early “weak” split can enable a later strong split.
Grow a large tree \(T_0\), then prune back to a subtree.
Use cost-complexity pruning: minimize
\[ C_\alpha(T) \;=\; R(T) \;+\; \alpha\,|T|, \]
where \(R(T)\) is the training loss and \(|T|\) is the number of leaves (terminal nodes). For regression trees,
\[C_\alpha (T) = \sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} (y_i - \hat{y}_{R_m})^2 + \alpha\, |T|;\]
for classification trees,
\[C_\alpha (T) = \sum_{m=1}^{|T|} N_m\,\mathrm{Imp}(R_m) + \alpha\, |T|.\]
The tuning parameter \(\alpha\) controls a trade-off between the subtree’s complexity and its fit to the training data.
The tuning parameter \(\alpha\) can be chosen using cross-validation.
We randomly split the data 50/50: 132 training and 131 test observations.
We fit a large regression tree on the training data and varied the cost-complexity parameter \(\alpha\) to create subtrees with different numbers of terminal nodes (leaves).
We performed 6-fold cross-validation to estimate the cross-validated MSE as a function of \(\alpha\).
# ----------------------- Baseball Example (Regression Tree) -----------------------
# Goal: Predict logSalary with a regression tree; choose subtree by 6-fold CV.
# --- Packages --------------------------------------------------------------------
library(ISLR2) # Hitters data
library(dplyr) # data wrangling
library(tree) # tree(), cv.tree(), prune.tree()
# --- Data prep -------------------------------------------------------------------
# 1) Load data
data("Hitters", package = "ISLR2")
# 2) Remove rows with any NA (ISLR2::Hitters has NAs in Salary)
# Note: This yields n = 263 (as in ISL).
Hitters <- na.omit(Hitters) %>%
# 3) Create outcome and keep a subset of predictors
mutate(logSalary = log(Salary)) %>%
select(Hits, RBI, Years, HmRun, logSalary, PutOuts, Walks, AtBat, Assists, Errors)
# --- Train / test split ----------------------------------------------------------
set.seed(100)
n <- nrow(Hitters)
n_train <- ceiling(n / 2) # 132 if n = 263
idx_tr <- sample(seq_len(n), n_train)
train.Hitters <- Hitters[idx_tr, ]
test.Hitters <- Hitters[-idx_tr, ]
# --- Fit a large tree on training data -------------------------------------------
tree_full <- tree(logSalary ~ ., data = train.Hitters)
# Inspect the tree
summary(tree_full)
Regression tree:
tree(formula = logSalary ~ ., data = train.Hitters)
Variables actually used in tree construction:
[1] "Years" "RBI" "PutOuts" "Hits"
Number of terminal nodes: 10
Residual mean deviance: 0.1754 = 21.4 / 122
Distribution of residuals:
Min. 1st Qu. Median Mean 3rd Qu. Max.
-0.9870 -0.2620 -0.0157 0.0000 0.2236 1.7200
# --- 6-fold CV to choose subtree size --------------------------------------------
# cv.tree() returns deviance by number of terminal nodes (size).
# For regression trees, 'dev' is proportional to SSE
cv_fit <- cv.tree(tree_full, K = 6)
# Quick look
print(names(cv_fit)) # "size", "dev", "k", "method"
[1] "size"   "dev"    "k"      "method"
cv_fit # print the full cv.tree object
$size
[1] 10 9 8 7 6 5 4 3 2 1
$dev
[1] 37.26339 36.88053 36.88053 37.30710 36.66607 38.93577 46.55905
[8] 47.39311 50.84878 100.51312
$k
[1] -Inf 1.225032 1.245748 1.402796 1.559641 2.452100 4.540760
[8] 5.578953 8.429781 49.720232
$method
[1] "deviance"
attr(,"class")
[1] "prune" "tree.sequence"
# Plot CV curve and mark the minimum
plot(cv_fit$size, cv_fit$dev, type = "b",
xlab = "Number of leaves (size)", ylab = "CV deviance",
main = "6-fold CV for cost-complexity pruning")
best_size <- cv_fit$size[which.min(cv_fit$dev)]
abline(v = best_size, lty = 2, col = "red")
# --- Prune to the selected size --------------------------------------------------
tree_pruned <- prune.tree(tree_full, best = best_size)
# Plot pruned tree
plot(tree_pruned, type = "uniform")
text(tree_pruned, pretty = 0)
# --- Evaluation on test set ------------------------------------------------------
y_hat_test <- predict(tree_pruned, newdata = test.Hitters)
y_test <- test.Hitters$logSalary
# Test MSE
mse_test <- mean((y_hat_test - y_test)^2)
cat(sprintf("Pruned tree size: %d leaves\n", best_size))Pruned tree size: 6 leaves
Test MSE: 0.4198
# --- Compare to the unpruned tree on test set -------------------------
y_hat_full <- predict(tree_full, newdata = test.Hitters)
mse_full <- mean((y_hat_full - y_test)^2)
cat(sprintf("Unpruned tree Test MSE: %.4f\n", mse_full))Unpruned tree Test MSE: 0.4319
# Predicted vs. observed plot
plot(y_hat_test, y_test,
xlab = "Predicted logSalary", ylab = "Observed logSalary",
main = "Test set: Predicted vs Observed")
abline(0, 1, col = "red", lwd = 2)
We use classification trees to analyze the Carseats data set.
In these data, Sales is a continuous variable, and so we begin by recoding it as a binary variable called High, which takes on a value of High if the Sales variable exceeds \(8\), and takes on a value of Low otherwise.
# ------------------- Carseats classification tree (with comments) -------------------
# Load data
data("Carseats", package = "ISLR2")
# Create a binary outcome: High = "High" if Sales > 8, else "Low"
# Then drop 'Sales' so it isn't used as a predictor.
Carseats <- Carseats |>
mutate(High = factor(ifelse(Sales <= 8, "Low", "High"))) |>
dplyr::select(-Sales)
# --- Train / test split (50/50) ---------------------------------------------------
set.seed(1) # for reproducibility
n <- nrow(Carseats)
n_train <- ceiling(n / 2) # size of the training set
idx_tr <- sample(seq_len(n), n_train)
train.Carseats <- Carseats[idx_tr, ]
test.Carseats <- Carseats[-idx_tr, ]
# --- Fit a classification tree on the training data --------------------------------
# Using the 'tree' package. For classification, the criterion is deviance (entropy).
tree.carseats <- tree(High ~ ., train.Carseats)
# Inspect the fitted tree (splits, node sizes, deviance, misclass error, etc.)
summary(tree.carseats)
Classification tree:
tree(formula = High ~ ., data = train.Carseats)
Variables actually used in tree construction:
[1] "Price" "Population" "US" "CompPrice" "Advertising"
[6] "Income" "ShelveLoc" "Age"
Number of terminal nodes: 20
Residual mean deviance: 0.4549 = 81.89 / 180
Misclassification error rate: 0.105 = 21 / 200
# --- Predict on the test set -------------------------------------------------------
y.test <- test.Carseats$High
# type = "class" returns the predicted class label (majority class in leaf);
# without type="class", predict() would return class probabilities.
y.pred <- predict(tree.carseats, test.Carseats, type = "class")
# Confusion matrix: rows = predicted, columns = actual
table(y.pred, y.test)
        y.test
y.pred   High Low
   High    44  35
   Low     37  84
# Test accuracy = (true positives + true negatives) / total test cases
sum(diag(table(y.pred, y.test))) / (n - n_train)
[1] 0.64
# --- Cross-validated pruning by misclassification error ----------------
set.seed(123) # CV reproducibility
cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass) # CV error vs. size
cv.carseats # columns: size (leaves), dev (CV misclass error), k (alpha)
$size
[1] 20 18 10 8 6 4 2 1
$dev
[1] 69 69 65 62 58 63 79 83
$k
[1] -Inf 0.0 0.5 1.5 2.0 4.0 12.0 19.0
$method
[1] "misclass"
attr(,"class")
[1] "prune" "tree.sequence"
# Choose subtree size with minimum CV error and prune to that size
best_size <- cv.carseats$size[which.min(cv.carseats$dev)]
prune.carseats <- prune.misclass(tree.carseats, best = best_size)
# Summary of the pruned tree
summary(prune.carseats)
Classification tree:
snip.tree(tree = tree.carseats, nodes = c(2L, 6L, 29L, 56L, 15L
))
Variables actually used in tree construction:
[1] "Price" "CompPrice" "Advertising" "ShelveLoc"
Number of terminal nodes: 6
Residual mean deviance: 0.8562 = 166.1 / 194
Misclassification error rate: 0.16 = 32 / 200
# Predictions from the pruned tree on the test set
y.pred <- predict(prune.carseats, test.Carseats, type = "class")
# Confusion matrix and test accuracy for the pruned model
table(y.pred, y.test)
        y.test
y.pred   High Low
   High    49  33
   Low     32  86
# Test accuracy for the pruned tree
sum(diag(table(y.pred, y.test))) / (n - n_train)
[1] 0.675
Use the Boston housing data to fit a regression tree.
You can load the data from the R package ISLR2.
# Load the Boston housing data from ISLR2
data("Boston", package = "ISLR2")
set.seed(101) # for reproducible train/test split
# 50/50 split (rounded up if odd n)
idx_tr_Boston <- sample(1:nrow(Boston), ceiling(nrow(Boston) / 2))
# Fit a *large* regression tree on the training half
tree.boston <- tree(medv ~ ., Boston, subset = idx_tr_Boston)
# Inspect the fitted tree (splits, node counts, RSS, etc.)
summary(tree.boston)
Regression tree:
tree(formula = medv ~ ., data = Boston, subset = idx_tr_Boston)
Variables actually used in tree construction:
[1] "rm" "lstat" "dis" "nox"
Number of terminal nodes: 11
Residual mean deviance: 12.05 = 2916 / 242
Distribution of residuals:
Min. 1st Qu. Median Mean 3rd Qu. Max.
-14.5700 -1.6710 -0.1037 0.0000 1.6370 18.2800
# ---- Cross-validation to choose subtree size ----
cv.boston <- cv.tree(tree.boston) # returns size (leaves), deviance, and k (alpha)
# Index of the minimum CV deviance
which.min(cv.boston$dev)
[1] 1
# Corresponding subtree size (number of leaves)
cv.boston$size[which.min(cv.boston$dev)]
[1] 11
In this case, the most complex tree under consideration is selected by cross-validation.
Tree-based methods are simple and interpretable.
They are flexible: usable for both regression and classification.
However, they often underperform top-performing supervised methods in prediction accuracy.
They can be unstable: small data changes may yield different trees and predictions.
Ensemble methods (bagging, random forests, boosting) grow many trees and combine their predictions.
Combining many trees often improves accuracy dramatically, at the cost of some loss of interpretability.
Bagging (bootstrap aggregating) improves accuracy by reducing variance via averaging many models.
Recall: for \(n\) independent observations \(Z_1,\dots,Z_n\) with variance \(\sigma^2\), the variance of the mean is \(\sigma^2/n\).
We don’t have multiple training sets, so we bootstrap from the single training set.
Generate \(B\) bootstrap samples; fit a model on each to get \(\hat f_{(b)}(x)\), then average: \[ \hat f_{\text{bag}}(x)=\frac{1}{B}\sum_{b=1}^B \hat f_{(b)}(x). \]
Works for regression and classification trees.
Classification: take the majority vote across the \(B\) trees.
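As a minimal illustration of bagging (in practice randomForest with mtry = p does exactly this), the sketch below reuses the Hitters training/test split created earlier; B = 100 is an arbitrary choice.
# Hand-rolled bagging of regression trees on the earlier Hitters split
library(tree)
set.seed(1)
B <- 100
pred_mat <- sapply(seq_len(B), function(b) {
  idx_b  <- sample(nrow(train.Hitters), replace = TRUE)     # bootstrap sample
  tree_b <- tree(logSalary ~ ., data = train.Hitters[idx_b, ])
  predict(tree_b, newdata = test.Hitters)                   # per-tree test predictions
})
y_hat_bag <- rowMeans(pred_mat)                              # average over the B trees
mean((y_hat_bag - test.Hitters$logSalary)^2)                 # bagged test MSE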
Random forests improve on bagging by decorrelating the trees.
As in bagging, fit many trees on bootstrap samples.
At each split, sample \(m\) predictors (from \(p\) total) and choose the best split only among those \(m\).
This random subsetting reduces tree-to-tree correlation → lower variance when averaging.
Typical choices: classification \(m \approx \sqrt{p}\); regression \(m \approx p/3\).
Random forests can use out-of-bag (OOB) error to tune \(m\) (and other settings) without a separate validation set.
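For example, a minimal sketch (assuming the train.Hitters split from earlier) that compares values of mtry by their OOB error; for regression forests, rf$mse stores the running OOB MSE.
# Compare mtry values by OOB MSE (no validation set needed)
library(randomForest)
set.seed(1)
p <- ncol(train.Hitters) - 1                  # number of predictors
oob_mse <- sapply(1:p, function(m) {
  rf <- randomForest(logSalary ~ ., data = train.Hitters, mtry = m, ntree = 500)
  tail(rf$mse, 1)                             # OOB MSE after 500 trees
})
data.frame(mtry = 1:p, oob_mse = round(oob_mse, 3))
which.min(oob_mse)                            # OOB-selected mtry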
Predict High with a random forest; report the OOB error.
# ---------------- Random Forest (Carseats) — train on train.Carseats, test on test.Carseats ----------------
library(randomForest)
set.seed(123) # for reproducible trees and OOB estimate
# Fit RF on the training set
# - ntree: number of trees to grow
# - mtry: ~sqrt(p) predictors tried at each split (p = number of features, exclude outcome)
# - importance: store variable importance measures
rf.Carseats <- randomForest(
High ~ .,
data = train.Carseats,
ntree = 500,
mtry = floor(sqrt(ncol(train.Carseats) - 1)), # use train.Carseats to count predictors
importance = TRUE
)
rf.Carseats # prints OOB error (approximate test error)
Call:
randomForest(formula = High ~ ., data = train.Carseats, ntree = 500, mtry = floor(sqrt(ncol(train.Carseats) - 1)), importance = TRUE)
Type of random forest: classification
Number of trees: 500
No. of variables tried at each split: 3
OOB estimate of error rate: 19%
Confusion matrix:
High Low class.error
High 58 25 0.3012048
Low 13 104 0.1111111
# ---- Test-set evaluation ----
# Predicted class labels on the held-out test set
y_pred <- predict(rf.Carseats, newdata = test.Carseats, type = "class")
y_true <- test.Carseats$High
# Confusion matrix
table(Predicted = y_pred, Actual = y_true)
          Actual
Predicted  High Low
     High    61  18
     Low     20 101
# Test accuracy
mean(y_pred == y_true)
[1] 0.81
Predict logSalary with a random forest; report the OOB MSE.
# ---------------- Random Forest (Hitters, regression) — test MSE on test.Hitters ----------------
set.seed(123) # reproducible RF
# Fit RF on the training set
# - For regression, randomForest default mtry ≈ p/3 (p = # predictors)
rf_reg <- randomForest(
logSalary ~ ., data = train.Hitters,
ntree = 500
)
rf_reg # prints OOB MSE (approx. test error)
Call:
randomForest(formula = logSalary ~ ., data = train.Hitters, ntree = 500)
Type of random forest: regression
Number of trees: 500
No. of variables tried at each split: 3
Mean of squared residuals: 0.2606431
% Var explained: 64.73
# ---- Test-set evaluation ----
# Predict logSalary on the held-out test set
y_pred <- predict(rf_reg, newdata = test.Hitters)
y_true <- test.Hitters$logSalary
# Test MSE
mse_test <- mean((y_pred - y_true)^2)
mse_test
[1] 0.3261364
The test MSE is \(0.3261\), a large improvement over the optimally pruned single tree (\(0.4198\)) and the unpruned tree (\(0.4319\)).
Boosting is an ensemble technique that combines multiple weak learners (typically decision trees) to form a stronger model.
It sequentially corrects mistakes made by previous models, with each new tree trained on the residuals of the current ensemble.
There are several boosting algorithms: Gradient Boosting Machines (GBM), AdaBoost, XGBoost, LightGBM, and CatBoost.
Boosting (especially XGBoost) often outperforms random forests on moderate-to-large datasets with many weak additive signals and interactions, especially when you can tune hyperparameters and use early stopping (it also handles sparsity and missing values well).
Prefer Random Forest when data are small or noisy and you need a robust baseline with minimal tuning and reliable OOB error.
Since the focus of our workshop is Machine Learning for Scientific Research, we will skip these models; they are listed here for your information.
Steps of boosting:
Initialization:
Start with an initial prediction. For regression, this could be the mean of the target; for classification, the log-odds of the classes. \[
F_0(x) = \frac{1}{n} \sum_{i=1}^{n} y_i
\]
Compute residuals:
For each observation, compute the residuals (the difference between the true outcome and the predicted outcome). \[
r_i = y_i - F_0(x_i)
\]
Fit a new model on residuals:
\[
h_1(x) = \text{Tree trained on residuals}
\] where \(h_1(x)\) is the prediction from the new model.
Update the predictions:
Add the new model’s predictions to the current predictions (scaled by the learning rate \(\alpha\)). \[
F_1(x)=F_0(x)+\alpha\,h_1(x)
\]
Repeat for further models:
Train each new tree on the residuals of the current model and update predictions. After \(t\) iterations: \[
F_t(x)=F_{t-1}(x)+\alpha\,h_t(x)
\]
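The steps above can be coded directly. Below is a minimal from-scratch sketch using shallow rpart trees as weak learners on the Hitters split from earlier; the learning rate alpha, tree depth, and number of rounds are illustrative choices, not tuned values.
# From-scratch gradient boosting with squared-error loss (illustration only)
library(rpart)
set.seed(1)
alpha   <- 0.1                                                       # learning rate
n_trees <- 200
F_train <- rep(mean(train.Hitters$logSalary), nrow(train.Hitters))   # F_0(x)
F_test  <- rep(mean(train.Hitters$logSalary), nrow(test.Hitters))
for (t in seq_len(n_trees)) {
  dat_t   <- train.Hitters
  dat_t$r <- train.Hitters$logSalary - F_train                       # residuals of current model
  h_t <- rpart(r ~ . - logSalary, data = dat_t,
               control = rpart.control(maxdepth = 2, cp = 0))        # weak learner on residuals
  F_train <- F_train + alpha * predict(h_t, dat_t)                   # update F_t(x)
  F_test  <- F_test  + alpha * predict(h_t, test.Hitters)
}
mean((F_test - test.Hitters$logSalary)^2)                            # test MSE of the boosted model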
Regularization in boosting: the learning rate \(\alpha\) (shrinkage), tree depth, and subsampling limit how much each tree contributes and help prevent overfitting.
Gradient boosting is a specific form of boosting in which new models are fit to the residual errors (negative gradients of the loss) of the current model.
Use pseudo-residuals from a differentiable loss \(\mathcal{L}(y, F(x))\): \[ g_i = \frac{\partial \mathcal{L}(y_i, F(x_i))}{\partial F(x_i)} \] Then fit a base learner (e.g., a tree) using \(g_i\) to get prediction \(h_t(x_i)\). (Commonly, the negative gradient is used as the pseudo-residual.)
In R, use the gbm package for Gradient Boosting Machines.
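For illustration, a minimal gbm sketch on the same Hitters split used earlier (the hyperparameters are illustrative, not tuned):
# GBM for logSalary with 5-fold CV to choose the number of trees
library(gbm)
set.seed(1)
gbm_fit <- gbm(logSalary ~ ., data = train.Hitters,
               distribution = "gaussian",            # squared-error loss
               n.trees = 2000, interaction.depth = 2,
               shrinkage = 0.01, cv.folds = 5)
best_iter <- gbm.perf(gbm_fit, method = "cv")        # CV-selected number of trees
y_hat <- predict(gbm_fit, newdata = test.Hitters, n.trees = best_iter)
mean((y_hat - test.Hitters$logSalary)^2)             # test MSE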
AdaBoost adjusts the weight of each data point based on whether it was correctly or incorrectly classified by the previous model. Misclassified points receive higher weights; in regression, larger residuals get more emphasis.
Initially, all \(n\) data points have equal weight \(1/n\).
Update weights based on residuals (regression) or misclassification (classification) so the next weak learner focuses on hard cases.
In R, the ada package implements AdaBoost and provides training and prediction functions.
XGBoost is an optimized implementation of gradient boosting and is widely used for classification and regression.
Built-in handling for missing data (learned default directions, no imputation needed).
Uses max depth (pre-pruning; stop when a leaf reaches max depth) and can also apply post-pruning.
Includes L1 (Lasso) and L2 (Ridge) regularization.
Designed for parallel processing, often much faster than traditional GBM.
Key parameters in XGBoost
eta (learning rate)
n_estimators (number of trees / boosting rounds; nrounds in the R package)
max_depth (maximum tree depth)
subsample — fraction of rows per boosting round (helps prevent overfitting)
colsample_bytree — fraction of features per tree (adds randomness, reduces overfitting)
lambda (L2 regularization)
alpha (L1 regularization)
gamma — minimum loss reduction to make a further partition
The xgboost package provides a fast, popular implementation of Extreme Gradient Boosting.
The caret package offers wrappers for AdaBoost, Gradient Boosting, and XGBoost for training and evaluation.
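For illustration, a minimal xgboost sketch on the Hitters split from earlier, using the classic xgb.DMatrix / xgb.cv / xgb.train interface; the parameter values are illustrative, and xgboost expects a numeric matrix (this predictor subset is already all numeric).
# XGBoost regression with 5-fold CV and early stopping
library(xgboost)
x_cols <- setdiff(names(train.Hitters), "logSalary")
dtrain <- xgb.DMatrix(as.matrix(train.Hitters[, x_cols]),
                      label = train.Hitters$logSalary)
params <- list(objective = "reg:squarederror",
               eta = 0.05, max_depth = 3,
               subsample = 0.8, colsample_bytree = 0.8,
               lambda = 1, alpha = 0)
set.seed(1)
cv_fit  <- xgb.cv(params, dtrain, nrounds = 1000, nfold = 5,
                  early_stopping_rounds = 20, verbose = 0)
xgb_fit <- xgb.train(params, dtrain, nrounds = cv_fit$best_iteration)
y_hat   <- predict(xgb_fit, as.matrix(test.Hitters[, x_cols]))
mean((y_hat - test.Hitters$logSalary)^2)              # test MSE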
Bagging and random forests provide variable importance to identify key predictors.
We record the total decrease in \(RSS\) (regression) or Gini index (classification) due to splits on a given predictor, averaged over the \(B\) trees. Larger values mean a more important predictor.
In the iml package in R, permutation importance (FeatureImp) ranks features by the drop in performance (e.g., MSE) when a feature is shuffled.
Use both: model-specific (randomForest) and model-agnostic (iml) views often agree but can highlight different aspects.
iml shows algorithm-independent importance/effects; RF importance shows what the forest actually used.
#load packages
library(randomForest)
# Load the Boston housing data from ISLR2
data("Boston", package = "ISLR2")
#set the seed
set.seed(100)
rf.boston <- randomForest(medv ~ ., data = Boston,
mtry = 6, importance = TRUE)
#Extract importance
importance(rf.boston)
           %IncMSE IncNodePurity
crim 16.570550 2222.8060
zn 3.295833 107.3477
indus 10.394410 1756.9428
chas 2.333964 131.5532
nox 19.773039 2381.7193
rm 44.451021 15412.2843
age 12.718606 848.6475
dis 20.808048 2399.7688
rad 7.103155 228.8922
tax 11.139908 854.1824
ptratio 15.678596 1755.9960
lstat 36.321009 14310.8429
%IncMSE is based on the mean decrease in prediction accuracy on the out-of-bag samples when a given variable is permuted.
IncNodePurity is a measure of the total decrease in node impurity that results from splits over that variable, averaged over all trees.
In the case of regression trees, the node impurity is measured by the training RSS, and for classification trees by the deviance.
The results indicate that across all of the trees considered in the random forest, the wealth of the community (lstat) and the house size (rm) are by far the two most important variables.
Permutation importance (iml): FeatureImp shuffles one feature, re-evaluates the model, and measures the drop in performance. A bigger drop means a more important feature.
Loss choice (regression): use Mean Absolute Error ("mae"); MSE ("mse") is another option. For classification, use "logLoss".
In R, we wrap the fitted model in a Predictor object (an R6 class), then call FeatureImp.
library("iml")
library("ggplot2")
X <- Boston[which(names(Boston) != "medv")]
predictor <- Predictor$new(rf.boston, data = X, y = Boston$medv)
imp <- FeatureImp$new(predictor, loss = "mse")
plot(imp)
imp$results # permutation importance table
    feature importance.05 importance importance.95 permutation.error
1 lstat 25.593076 28.735468 30.702135 51.290594
2 rm 18.523422 19.210765 20.507692 34.289734
3 dis 3.499887 3.626493 3.968537 6.473010
4 nox 3.285675 3.439364 3.691834 6.138999
5 crim 3.029992 3.068844 3.222615 5.477650
6 ptratio 2.144980 2.300673 2.393115 4.106524
7 tax 1.739776 1.761973 1.773983 3.144985
8 indus 1.711617 1.733345 1.833625 3.093887
9 age 1.673318 1.730406 1.801173 3.088642
10 rad 1.143681 1.161688 1.188848 2.073524
11 chas 1.048845 1.053703 1.058204 1.880779
12 zn 1.043218 1.044404 1.049096 1.864181
Feature effects (iml): how predictions change as one feature varies (PDP, ALE, ICE).
PDP (Partial Dependence)
Averages predictions over other features: effect of one feature on the prediction level.
Simple, global view; can be biased with correlated features.
iml::FeatureEffect(..., method = "pdp")
ALE (Accumulated Local Effects)
Averages local prediction changes within bins and integrates; centered to mean 0.
More robust with correlated features; shows relative effect (not absolute level).
iml::FeatureEffect(..., method = "ale")
ICE (Individual Conditional Expectation)
iml::FeatureEffect(..., method = "ice")Quantify how much a feature participates in interactions (beyond its main effect).
Interaction strength is measured by Friedman’s H statistic, scaled 0–1.
Intuition: measure the fraction of the model’s variance not explained by the feature’s 1D effect alone.
We can also specify a feature and measure all of its 2-way interactions with all other features.
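For example, reusing the predictor object built above for the Boston random forest (the feature choices here are illustrative):
# Feature effects for a single feature
pdp_rm    <- FeatureEffect$new(predictor, feature = "rm", method = "pdp")
plot(pdp_rm)
ale_lstat <- FeatureEffect$new(predictor, feature = "lstat", method = "ale")
plot(ale_lstat)
# Overall interaction strength (Friedman's H) for every feature (can be slow)
ia_all <- Interaction$new(predictor)
plot(ia_all)
# All 2-way interactions of a chosen feature with the other features
ia_rm <- Interaction$new(predictor, feature = "rm")
plot(ia_rm)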
Fit a small decision tree to the black-box predictions (e.g., random forest) using the original features.
The tree approximates the black box and yields simple rules (terminal nodes).
Control interpretability with maxdepth (shallower → simpler).
Note: a surrogate explains the model (the black box), not necessarily the true data-generating process.
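A short sketch of a global surrogate tree for the Boston random forest, again reusing the predictor object from above; maxdepth = 2 is an illustrative choice.
# Shallow tree fit to the random forest's predictions (global surrogate)
surrogate <- TreeSurrogate$new(predictor, maxdepth = 2)
plot(surrogate)          # simple rules approximating the black box
surrogate$r.squared      # how well the surrogate reproduces the forest's predictions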
Reference: “Introduction to iml: Interpretable Machine Learning in R” (CRAN vignette).
Causal inference is the process of drawing conclusions about causal relationships from data.
The primary goal is to determine cause-and-effect relationships rather than correlations or associations.
The causal effect of a treatment is the difference between the outcome with treatment and the outcome without treatment.
Randomized Controlled Trials
Participants are randomly assigned to either the treatment or control group.
Example: Participants are randomly assigned to receive a new drug or a placebo. By comparing the outcomes between the two groups, we can estimate the causal effect of the drug.
Observational Studies
While randomized controlled trials are ideal, they may be infeasible or unethical; in those cases we rely on observational data.
Observational data are prone to confounding, which occurs when a third variable affects both the treatment and the outcome.
Example: studying the effect of Ozempic on weight loss; if adults with type 2 diabetes are likelier to use the drug and tend to have different weight-loss patterns, naive comparisons are biased.
There are several types of treatment effects depending on the scope and the group of people being considered:
Average Treatment Effect (ATE)
\[\text{ATE} = \text{E}(Y(1) - Y(0))\]
\(Y(1)\) is the outcome under treatment.
\(Y(0)\) is the outcome under no treatment.
\(\text{E}\) denotes the expectation (average) over the population.
Average Treatment Effect on the Treated (ATT)
\[\text{ATT} = \text{E}[Y(1) - Y(0) | W_i = 1]\]
\(W_i = 1\) denotes treated observations.
This measures the average effect of the treatment on those individuals who actually received the treatment.
Average Treatment Effect on the Untreated (ATU):
\[\text{ATU} = \text{E}[Y(1) - Y(0) | W_i = 0]\]
This measures the average effect of the treatment on those individuals who did not receive the treatment.
Homogeneous Treatment Effect:
Suppose we have a partially linear regression model (PLR):
\[Y = \tau W +g(X) + \epsilon\]
\(Y\): outcome
\(W\): treatment (often binary but can be continuous)
\(\tau\): Homogeneous treatment effect. It is constant across individuals.
\(g(X)\): nuisance function (possibly nonlinear effect of X)
\(\epsilon\): error
Heterogeneous Treatment Effects (HTE)
When treatment effects vary across individuals:
\[Y = \tau(X) W +g(X) + \epsilon\] \(\tau(X)\): conditional treatment effect (CATE), varies with covariates
If the treatment is binary (0, 1) then we define:
1. Individual Treatment Effect (ITE)
\[\text{ITE}_i = Y_i(1) - Y_i(0)\]
2. Conditional Average Treatment Effect (CATE)
\[\text{CATE} = \text{E}[Y(1) - Y(0) | X]\]
Robinson (1988) proposed a clever way to estimate \(\tau\) without needing to perfectly specify \(g(X)\). The steps are as follows:
Step 1: Estimate nuisance functions
\[m(X) = E(Y | X) \text{ (baseline outcome)} \\ r(X) = E(W | X) \text{ (propensity score for binary W)}\]
Step 2: Residualize:
\[\widetilde{Y} = Y - m(X), \\ \widetilde{W} = W - r(X)\]
Step 3: Estimate treatment effect
\(\widetilde{Y} = \tau \widetilde{W} +\epsilon\)
If the treatment effect is homogeneous, we can use OLS (a no-intercept regression of \(\widetilde{Y}\) on \(\widetilde{W}\)):
\[\hat{\tau} = \dfrac{\sum_i^n \widetilde{W_i} \widetilde{Y_i}}{\sum_i^n\widetilde{W_i}^2}\]
In the case of heterogeneous treatment effects, for each \(X\) we estimate \(\tau(X)\) by minimizing
\[\arg\min_{\tau(\cdot)} \sum_i \left[\widetilde{Y_i} - \tau(X_i)\, \widetilde{W_i}\right]^2\] This generalization of Robinson’s idea is called the R-learner (Nie & Wager, 2021).
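A minimal simulated-data sketch of the residual-on-residual idea with a homogeneous effect; for simplicity both nuisance functions are fit with lm() here (in practice any ML learner can be used, ideally with cross-fitting).
# Robinson residualization: recover tau = 2 despite confounding by x
set.seed(1)
n <- 1000
x <- rnorm(n)
w <- rbinom(n, 1, plogis(0.5 * x))        # treatment depends on x (confounding)
y <- 2 * w + sin(x) + rnorm(n)            # true homogeneous effect tau = 2
m_hat <- fitted(lm(y ~ poly(x, 3)))       # estimate of E(Y | X)
r_hat <- fitted(lm(w ~ poly(x, 3)))       # estimate of E(W | X) (propensity)
y_tilde <- y - m_hat
w_tilde <- w - r_hat
coef(lm(y_tilde ~ w_tilde - 1))           # residual-on-residual slope, approximately 2
sum(w_tilde * y_tilde) / sum(w_tilde^2)   # same estimate written as the ratio above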
Causal Forest is an extension of Random Forest to estimate the heterogeneous treatment effects (HTEs).
The goal is to estimate individual-level treatment effects as a function of covariates.
In R, Causal Forest is implemented in the R package grf
grf stands for Generalized Random Forests and it generalizes the idea of Random Forest beyond just prediction of outcomes to a whole family of statistical estimation problems, including:
Treatment effect estimation (CATE, ATE, ATT, ATC)
Quantile regression
Instrumental variables regression
Random Forest: Built to predict an outcome \(Y\) from features \(X\).
Causal Forest: built to predict a treatment effect \[\tau(x) = \text{E}[Y(1) - Y(0) \mid X=x].\]
Causal forests introduce several modifications compared to random forests:
Splitting rule in tree: splits are chosen to maximize heterogeneity in the estimated treatment effect, rather than to reduce prediction error.
Honest estimation: separate subsamples are used to choose the splits and to estimate the effect within each leaf.
Out-of-bag (OOB) predictions: each observation’s CATE is predicted using only trees that did not use that observation.
Bootstrap Sampling: Take a bootstrap sample of the data for tree \(b=1,\cdots,B\).
Honest Splitting: divide the sample into two halves:
Estimate Nuisance Functions (Robinson Transformation)
\[\hat{m}(X) = E(Y | X) \\ \hat{r}(X) = E(W | X) \text{ (propensity score)}\]
Evaluate candidate splits based on treatment effect heterogeneity in the splitting half.
- Use residualized outcomes
\[\widetilde{Y_i} = Y_i - \hat{m}(X_i), \\ \widetilde{W_i} = W_i - \hat{r}(X_i)\]
For a candidate split \((j,s)\), partition the node into a left child \(L\) and a right child \(R\).
Estimate the local effect in each child via residual-on-residual slope (no intercept):
\[ \hat\tau_m = \frac{\sum_{i\in R_m} \tilde W_i\,\tilde Y_i}{\sum_{i\in R_m} \tilde W_i^{\,2}}, \quad m\in\{L,R\}. \] Choose the split that maximizes the treatment-effect heterogeneity score
\[ \text{Score}(j,s)=\frac{N_L\,N_R}{N}\,\big(\hat\tau_L-\hat\tau_R\big)^2. \]
Choose \((j^*,s^*)=\arg\max_{(j,s)} \text{Score}(j,s)\) and recurse on the children; a small numeric sketch of this score follows the algorithm.
Tree Growth: Repeat recursively until stopping criteria are met (e.g., minimum node size).
Treatment Effect Estimation: For each leaf node, use the estimation half to compute:
\[\hat{\tau}_{leaf} = \dfrac{\sum_i \widetilde{W_i} \widetilde{Y_i}}{\sum_i\widetilde{W_i}^2}\]
Out-of-Bag Prediction: Predict CATE for each observation (in the data) using trees where the observation was not in the bootstrap sample (OOB).
Aggregate Across Trees: Average CATE estimates across all trees for the final prediction (for new data).
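To illustrate the splitting rule, here is a small sketch (not grf’s internal code) that computes \(\hat\tau_L\), \(\hat\tau_R\), and the heterogeneity score for candidate cutpoints of one feature, given already-residualized outcomes and treatments.
# Heterogeneity score for a candidate split of feature x_j at cutpoint s
split_score <- function(y_tilde, w_tilde, x_j, s) {
  L <- x_j <  s
  R <- x_j >= s
  tau_L <- sum(w_tilde[L] * y_tilde[L]) / sum(w_tilde[L]^2)
  tau_R <- sum(w_tilde[R] * y_tilde[R]) / sum(w_tilde[R]^2)
  (sum(L) * sum(R) / length(x_j)) * (tau_L - tau_R)^2
}
# Toy check: the effect jumps at x = 0, so cutpoints near 0 should score highest
set.seed(1)
x  <- runif(500, -1, 1)
w  <- rbinom(500, 1, 0.5)                    # randomized treatment for simplicity
y  <- ifelse(x > 0, 3, 1) * w + rnorm(500)
yt <- y - mean(y); wt <- w - mean(w)         # crude residualization (fine here since W is randomized)
s_grid <- seq(-0.8, 0.8, by = 0.1)
scores <- sapply(s_grid, function(s) split_score(yt, wt, x, s))
s_grid[which.max(scores)]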
Evaluate the effect of a healthcare program (treatment) on blood pressure reduction.
Covariates (patient characteristics, 10 total):
Treatment assignment (W):
Not randomized: enrollment probability is higher for patients with BMI > 27 (see the simulation code below)
Outcome (Y):
Reduction in systolic blood pressure (mmHg)
Influenced by:
Baseline covariates (e.g., age, smoking)
A treatment effect that grows with BMI above 25 (heterogeneous effect)
Random noise
library(grf)
#----------------------------
# Example: Healthcare program effect on blood pressure
#----------------------------
set.seed(123)
n <- 2000
p <- 10
# Baseline covariates (pretend these are patient characteristics)
patient_covariates <- data.frame(
age = rnorm(n, 50, 12),
bmi = rnorm(n, 27, 5),
smoker = rbinom(n, 1, 0.3),
cholesterol= rnorm(n, 200, 30),
exercise = rpois(n, 2),
income = rnorm(n, 60000, 15000),
diabetes = rbinom(n, 1, 0.2),
region = rbinom(n, 1, 0.5),
stress = rnorm(n, 5, 2),
alcohol = rpois(n, 3)
)
X <- as.matrix(patient_covariates)
# Treatment assignment (healthcare program enrollment depends on BMI & smoking)
W <- rbinom(n, 1, 0.4 + 0.2 * (X[, "bmi"] > 27))
# Outcome: reduction in systolic blood pressure
# baseline risk factors + heterogeneity in treatment effect
Y <- pmax(X[, "bmi"] - 25, 0) * W +
(-0.5)*X[, "smoker"] +
(-0.02)*X[, "age"] +
rnorm(n)
#----------------------------
# True treatment effect
#----------------------------
# tau(X) = max(bmi-25, 0)
# so higher-BMI patients benefit more
true_tau <- pmax(X[, "bmi"] - 25, 0)
#----------------------------
# Train a causal forest
#----------------------------
tau.forest <- causal_forest(X, Y, W)
# OOB predicted treatment effects
tau.hat.oob <- predict(tau.forest)$predictions
hist(tau.hat.oob, main = "Estimated Individual Treatment Effects",
xlab = "Predicted tau", col = "skyblue")
# Calibration check of the CATE estimates (grf::test_calibration)
test_calibration(tau.forest)
Best linear fit using forest predictions (on held-out data)
as well as the mean forest prediction as regressors, along
with one-sided heteroskedasticity-robust (HC3) SEs:
Estimate Std. Error t value Pr(>t)
mean.forest.prediction 0.998764 0.015789 63.256 < 2.2e-16 ***
differential.forest.prediction 1.016186 0.016495 61.605 < 2.2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# Test set to visualize effect curve w.r.t BMI
bmi.test <- seq(15, 40, length.out = 100)
X.test <- matrix(0, 100, p)
colnames(X.test) <- colnames(X)
X.test[, "bmi"] <- bmi.test
# Naive difference in means (biased here because treatment depends on BMI)
mean(Y[W == 1]) - mean(Y[W == 0])
[1] 3.855895
tau.hat.test <- predict(tau.forest, X.test)$predictions
plot(bmi.test, tau.hat.test, type = "l", lwd = 2,
ylab = "Estimated Treatment Effect",
xlab = "BMI")
lines(bmi.test, pmax(bmi.test - 25, 0), col = 2, lty = 2)
# For comparison: a naive linear model with a single homogeneous coefficient on W
# (dat is assumed to combine Y, W, and the covariates; not shown in the original code)
dat <- data.frame(Y = Y, W = W, patient_covariates)
lm(Y ~ ., data = dat)
Call:
lm(formula = Y ~ ., data = dat)
Coefficients:
(Intercept) W age bmi smoker cholesterol
-8.798e+00 3.272e+00 -2.333e-02 3.398e-01 -3.077e-01 7.369e-04
exercise income diabetes region stress alcohol
-1.126e-02 -4.580e-06 1.627e-01 -8.536e-02 3.479e-03 3.293e-02
#----------------------------
# Average treatment effects
#----------------------------
average_treatment_effect(tau.forest, target.sample = "all")      # ATE
  estimate    std.err
3.14952026 0.09200541
# ATT: average effect among the treated
average_treatment_effect(tau.forest, target.sample = "treated")
 estimate   std.err
3.8421107 0.1217411
# ATU: average effect among the untreated (controls)
average_treatment_effect(tau.forest, target.sample = "control")
estimate  std.err
2.482822 0.112696
#----------------------------
# Confidence intervals with more trees
#----------------------------
tau.forest <- causal_forest(X, Y, W, num.trees = 4000)
tau.hat <- predict(tau.forest, X.test, estimate.variance = TRUE)
sigma.hat <- sqrt(tau.hat$variance.estimates)
plot(bmi.test, tau.hat$predictions, type = "l", lwd = 2,
ylim = range(c(tau.hat$predictions + 1.96 * sigma.hat,
tau.hat$predictions - 1.96 * sigma.hat, 0, 8)),
xlab = "BMI", ylab = "Treatment Effect")
lines(bmi.test, tau.hat$predictions + 1.96 * sigma.hat, lty = 2)
lines(bmi.test, tau.hat$predictions - 1.96 * sigma.hat, lty = 2)
lines(bmi.test, pmax(bmi.test - 25, 0), col = 2, lwd = 2)
legend("topleft", legend = c("Estimated tau", "95% CI", "True tau"),
col = c(1, 1, 2), lty = c(1, 2, 1), lwd = c(2, 1, 2))
# Calibration check for the larger forest (grf::test_calibration)
test_calibration(tau.forest)
Best linear fit using forest predictions (on held-out data)
as well as the mean forest prediction as regressors, along
with one-sided heteroskedasticity-robust (HC3) SEs:
Estimate Std. Error t value Pr(>t)
mean.forest.prediction 0.997629 0.015741 63.379 < 2.2e-16 ***
differential.forest.prediction 1.017300 0.016635 61.156 < 2.2e-16 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Athey, S., & Wager, S. (2019). Estimating treatment effects with Causal Forests: An application, Observational Studies, 5(2), 37-51.
Breiman, L. (1996). Bagging predictors. Machine Learning, 24, 123–140.
Breiman, L. (2001). Random forests. Machine Learning, 45, 5–32.
Breiman, L., Friedman, J. H., Olshen, R. A., & Stone, C. J. (1984). Classification and Regression Trees. Wadsworth.
James, G., Witten, D., Hastie, T., & Tibshirani, R. (2021). An Introduction to Statistical Learning with Applications in R (2nd ed.). Springer.
Molnar, C. (2022). Interpretable Machine Learning: A Guide for Making Black Box Models Explainable (2nd ed.).
Wager, S., & Athey, S. (2018). Estimation and inference of heterogeneous treatment effects using random forests. Annals of Statistics, 46(2), 674–700. https://doi.org/10.1214/18-AOS1709