7. Generalized Random Forest#
Developed by Susan Athey, Julie Tibshirani, and Stefan Wager, Generalized Random Forest (GRF) [8] estimates the solution to a set of local moment equations:

\[
\mathbb{E}\big[\psi_{\tau(s),\nu(s)}(O_i)\,\big|\, S_i=s\big]=0,
\]
where \(\tau(s)\) is the parameter we care about and \(\nu(s)\) is an optional nuisance parameter. In the problem of Heterogeneous Treatment Effect (HTE) estimation, our parameter of interest \(\tau(s)=\xi\cdot \beta(s)\) is identified by the score function

\[
\psi_{\beta(s),\nu(s)}(R_i,A_i)=\big(R_i-\beta(s)\cdot A_i-\nu(s)\big)\big(1,A_i^T\big)^T.
\]
The induced estimator \(\hat{\tau}(s)\) for \(\tau(s)\) can thus be solved in closed form as

\[
\hat{\tau}(s)=\xi^T\left(\sum_{i=1}^n \alpha_i(s)\big(A_i-\bar{A}_\alpha\big)^{\otimes 2}\right)^{-1}\sum_{i=1}^n \alpha_i(s)\big(A_i-\bar{A}_\alpha\big)\big(R_i-\bar{R}_\alpha\big),
\]
where \(\bar{A}_\alpha=\sum \alpha_i(s)A_i\) and \(\bar{R}_\alpha=\sum \alpha_i(s)R_i\), and we write \(v^{\otimes 2}=vv^T\).
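This closed form is just a weighted least-squares solve. A minimal NumPy sketch, assuming the weights \(\alpha_i(s)\) sum to one (the helper name `grf_tau_hat` and the toy inputs below are ours, not part of any package):

```python
import numpy as np

def grf_tau_hat(A, R, alpha, xi):
    """Solve the weighted moment condition for tau(s) = xi . beta(s).

    A: (n, p) treatments, R: (n,) responses,
    alpha: (n,) nonnegative weights at the target point s (summing to 1),
    xi: (p,) contrast vector.  Hypothetical helper for illustration only.
    """
    A_bar = alpha @ A                           # weighted mean, A-bar_alpha
    R_bar = alpha @ R                           # weighted mean, R-bar_alpha
    Ac = A - A_bar                              # centered treatments
    # sum_i alpha_i (A_i - A_bar)(A_i - A_bar)^T
    M = (alpha[:, None] * Ac).T @ Ac
    # sum_i alpha_i (A_i - A_bar)(R_i - R_bar)
    v = (alpha[:, None] * Ac).T @ (R - R_bar)
    return xi @ np.linalg.solve(M, v)
```

With a single binary treatment and uniform weights, this reduces to the ordinary difference-in-means slope.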
Notice that this formula is just a weighted version of the R-learner introduced above. However, instead of using ordinary kernel weighting functions, which suffer from a severe curse of dimensionality, GRF uses an adaptive weighting function \(\alpha_i(s)\) derived from a forest grown to express heterogeneity in the specified quantity of interest.
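For contrast, an ordinary kernel weighting function depends only on the covariate distance to the target point. A minimal Gaussian-kernel sketch (the helper name `kernel_weights` is ours):

```python
import numpy as np

def kernel_weights(S, s, h=1.0):
    """Nadaraya-Watson style weights: decay with distance from s, bandwidth h."""
    d2 = np.sum((S - s) ** 2, axis=1)      # squared Euclidean distances to s
    w = np.exp(-d2 / (2.0 * h ** 2))       # Gaussian kernel
    return w / w.sum()                     # normalize to sum to 1
```

As the covariate dimension grows, such distance-based weights concentrate on very few points or spread nearly uniformly, which is the curse of dimensionality that GRF's adaptive weights are designed to avoid.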
To be more specific, to obtain \(\alpha_i(s)\), GRF first grows a set of \(B\) trees indexed by \(b=1,\dots,B\). For each such tree, define \(L_b(s)\) as the set of training samples falling in the same leaf as \(s\). The weight \(\alpha_i(s)\) then captures the frequency with which the \(i\)-th training example falls into the same leaf as \(s\):

\[
\alpha_i(s)=\frac{1}{B}\sum_{b=1}^{B}\frac{\mathbb{1}\{S_i\in L_b(s)\}}{|L_b(s)|}.
\]
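This averaging can be computed directly from leaf memberships. A minimal sketch, where the helper name and array layout are our own (real implementations query each fitted tree, e.g. via its `apply` method):

```python
import numpy as np

def forest_weights(train_leaves, target_leaves):
    """Compute alpha_i(s) from B fitted trees (illustrative helper).

    train_leaves: (B, n) leaf id of each training sample in each tree,
    target_leaves: (B,) leaf id containing the target point s in each tree.
    """
    B, n = train_leaves.shape
    alpha = np.zeros(n)
    for b in range(B):
        in_leaf = train_leaves[b] == target_leaves[b]  # 1{S_i in L_b(s)}
        alpha += in_leaf / in_leaf.sum()               # normalize by |L_b(s)|
    return alpha / B                                   # average over the B trees
```

By construction the weights are nonnegative and sum to one, so they define a data-adaptive local neighborhood around \(s\).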
To sum up, GRF leverages the splitting results of a series of trees to decide the "localized" weights for HTE estimation at each target point \(s\). Compared with kernel functions, we may expect tree-based weights to be more flexible and to perform better in real settings.
# import related packages
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from causaldm._util_causaldm import *
from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL
from causaldm.learners.CEL.Single_Stage.LpRlearner import LpRlearner
MovieLens Data#
# Get the MovieLens data
MovieLens_CEL = _env_getdata_CEL.get_movielens_CEL()
MovieLens_CEL.pop(MovieLens_CEL.columns[0])
MovieLens_CEL
 | user_id | movie_id | rating | age | Comedy | Drama | Action | Thriller | Sci-Fi | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 48.0 | 1193.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
1 | 48.0 | 919.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
2 | 48.0 | 527.0 | 5.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
3 | 48.0 | 1721.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
4 | 48.0 | 150.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
65637 | 5878.0 | 3300.0 | 2.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
65638 | 5878.0 | 1391.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
65639 | 5878.0 | 185.0 | 4.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
65640 | 5878.0 | 2232.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
65641 | 5878.0 | 426.0 | 3.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
65642 rows × 15 columns
n = len(MovieLens_CEL)
# column indices of the user covariates S
userinfo_index = np.array([3,5,6,7,8,9,10])
# covariates S together with the treatment column A
SandA = MovieLens_CEL.iloc[:, np.array([3,4,5,6,7,8,9,10])]
The generalized random forest (GRF) approach has been implemented in the grf package for R and C++, and in the econml package for Python. Here we use econml for a simple illustration.
# import the package for Causal Random Forest
! pip install econml
# A demo of the Causal Forest implemented in econml
from econml.grf import CausalForest

est = CausalForest(criterion='het', n_estimators=400, min_samples_leaf=5, max_depth=None,
                   min_var_fraction_leaf=None, min_var_leaf_on_val=True,
                   min_impurity_decrease=0.0, max_samples=0.45, min_balancedness_tol=.45,
                   warm_start=False, inference=True, fit_intercept=True, subforest_size=4,
                   honest=True, verbose=0, n_jobs=-1, random_state=1235)
# fit(X, T, y): covariates, treatment (Drama indicator), and outcome (rating)
est.fit(MovieLens_CEL.iloc[:,userinfo_index], MovieLens_CEL['Drama'], MovieLens_CEL['rating'])
HTE_GRF = est.predict(MovieLens_CEL.iloc[:,userinfo_index], interval=False, alpha=0.05)
HTE_GRF = HTE_GRF.flatten()
Let’s focus on the estimated HTEs for three randomly chosen users:
print("Generalized Random Forest: ",HTE_GRF[np.array([0,300,900])])
Generalized Random Forest: [0.3588 0.3588 1.7786]
ATE_GRF = np.sum(HTE_GRF)/n
print("Choosing Drama instead of Sci-Fi is expected to improve the rating on average by",round(ATE_GRF,4), "out of 5 points.")
Choosing Drama instead of Sci-Fi is expected to improve the rating on average by 1.0468 out of 5 points.
Conclusion: Choosing Drama instead of Sci-Fi is expected to improve the average rating by 1.0468 out of 5 points.
References#
Susan Athey, Julie Tibshirani, and Stefan Wager. Generalized random forests. The Annals of Statistics, 47(2):1148–1178, 2019.