7. Generalized Random Forest#

Developed by Susan Athey, Julie Tibshirani, and Stefan Wager, the Generalized Random Forest (GRF) [8] aims to solve a set of local moment equations:

(34)#\[\begin{equation} \mathbb{E}\big[\psi_{\tau(s),\nu(s)}(O_i)\big| S_i=s\big]=0, \end{equation}\]

where \(\tau(s)\) is the parameter of interest and \(\nu(s)\) is an optional nuisance parameter. In the problem of heterogeneous treatment effect (HTE) estimation, our parameter of interest \(\tau(s)=\xi\cdot \beta(s)\) is identified by

(35)#\[\begin{equation} \psi_{\beta(s),\nu(s)}(R_i,A_i)=(R_i-\beta(s)\cdot A_i-c(s))(1 \quad A_i^T)^T. \end{equation}\]

The induced estimator \(\hat{\tau}(s)\) of \(\tau(s)\) can then be computed as

(36)#\[\begin{equation} \hat{\tau}(s)=\xi^T\left(\sum_{i=1}^n \alpha_i(s)\big(A_i-\bar{A}_\alpha\big)^{\otimes 2}\right)^{-1}\sum_{i=1}^n \alpha_i(s)\big(A_i-\bar{A}_\alpha\big)\big(R_i-\bar{R}_\alpha\big), \end{equation}\]

where \(\bar{A}_\alpha=\sum \alpha_i(s)A_i\) and \(\bar{R}_\alpha=\sum \alpha_i(s)R_i\), and we write \(v^{\otimes 2}=vv^T\).
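Equation (36) is easy to check numerically. Below is a minimal sketch for a scalar action, where \(\xi=1\) and the matrix inverse reduces to a division; the weights alpha are made up for illustration and merely stand in for the forest-based weights \(\alpha_i(s)\) defined later in this section.

# A minimal numerical check of equation (36) for a scalar action A_i.
# The weights below are made up for illustration; in GRF they come from
# the fitted forest (see the weighting scheme below).
import numpy as np

rng = np.random.default_rng(0)
n_toy = 200
A_toy = rng.binomial(1, 0.5, n_toy).astype(float)     # actions A_i
R_toy = 1.0 + 0.5 * A_toy + rng.normal(size=n_toy)    # responses R_i, true effect 0.5
alpha = rng.dirichlet(np.ones(n_toy))                 # stand-in weights alpha_i(s), summing to 1

A_bar = np.sum(alpha * A_toy)                         # weighted mean \bar{A}_alpha
R_bar = np.sum(alpha * R_toy)                         # weighted mean \bar{R}_alpha
tau_hat = (np.sum(alpha * (A_toy - A_bar) * (R_toy - R_bar))
           / np.sum(alpha * (A_toy - A_bar) ** 2))
print(tau_hat)                                        # should be close to the true effect 0.5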

Notice that equation (36) is just a weighted version of the R-learner introduced above. However, instead of ordinary kernel weighting functions, which suffer severely from the curse of dimensionality, GRF uses an adaptive weighting function \(\alpha_i(s)\) derived from a forest designed to express heterogeneity in the specified quantity of interest.

To be more specific, in order to obtain \(\alpha_i(s)\), GRF first grows a set of \(B\) trees indexed by \(b=1,\dots,B\). Then, for each tree \(b\), define \(L_b(s)\) as the set of training samples falling in the same "leaf" as \(s\). The weights \(\alpha_i(s)\) then capture the frequency with which the \(i\)-th training example falls into the same leaf as \(s\):

(37)#\[\begin{equation} \alpha_{bi}(s)=\frac{\boldsymbol{1}\big(\{S_i\in L_b(s)\}\big)}{\big|L_b(s)\big|},\quad \alpha_i(s)=\frac{1}{B}\sum_{b=1}^B \alpha_{bi}(s). \end{equation}\]

To sum up, GRF leverages the splitting results of a series of trees to decide the "localized" weights for HTE estimation at each point \(s\). Compared with kernel functions, tree-based weights can be expected to be more flexible and to perform better in real settings.
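The weighting scheme in equation (37) can be illustrated with any fitted forest. The sketch below uses an ordinary sklearn RandomForestRegressor as a stand-in for GRF's trees (GRF grows its trees with a custom splitting rule, so this demonstrates only the leaf co-membership bookkeeping, not GRF's splits); the toy data S_toy and R_toy are made up for illustration.

# A sketch of equation (37): alpha_i(s) is the average over trees of
# 1{S_i in L_b(s)} / |L_b(s)|, computed from leaf co-membership.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(1)
S_toy = rng.normal(size=(500, 3))                  # toy covariates S_i
R_toy = S_toy[:, 0] + rng.normal(size=500)         # toy responses, just to grow the trees
forest = RandomForestRegressor(n_estimators=50, min_samples_leaf=5,
                               random_state=1).fit(S_toy, R_toy)

def forest_weights(forest, S_train, s0):
    train_leaves = forest.apply(S_train)           # (n, B): leaf id of each training sample in each tree
    s0_leaves = forest.apply(s0.reshape(1, -1))    # (1, B): leaf id of the target point s0
    same_leaf = train_leaves == s0_leaves          # (n, B): indicators 1{S_i in L_b(s0)}
    alpha_b = same_leaf / same_leaf.sum(axis=0)    # per-tree weights 1{S_i in L_b(s0)} / |L_b(s0)|
    return alpha_b.mean(axis=1)                    # average over the B trees

alpha_toy = forest_weights(forest, S_toy, S_toy[0])
print(alpha_toy.sum())                             # each weight vector alpha_i(s) sums to 1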

# import related packages
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression 
from causaldm._util_causaldm import *
from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL
from causaldm.learners.CEL.Single_Stage.LpRlearner import LpRlearner

MovieLens Data#

# Get the MovieLens data
MovieLens_CEL = _env_getdata_CEL.get_movielens_CEL()
MovieLens_CEL.pop(MovieLens_CEL.columns[0])  # drop the first column, which is not needed for the analysis
MovieLens_CEL
|  | user_id | movie_id | rating | age | Comedy | Drama | Action | Thriller | Sci-Fi | gender_M | occupation_academic/educator | occupation_college/grad student | occupation_executive/managerial | occupation_other | occupation_technician/engineer |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 48.0 | 1193.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 1 | 48.0 | 919.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 2 | 48.0 | 527.0 | 5.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 3 | 48.0 | 1721.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 4 | 48.0 | 150.0 | 4.0 | 25.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … |
| 65637 | 5878.0 | 3300.0 | 2.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65638 | 5878.0 | 1391.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65639 | 5878.0 | 185.0 | 4.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65640 | 5878.0 | 2232.0 | 1.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 65641 | 5878.0 | 426.0 | 3.0 | 25.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |

65642 rows × 15 columns

n = len(MovieLens_CEL)
userinfo_index = np.array([3,5,6,7,8,9,10])      # column indices of the user covariates S
SandA = MovieLens_CEL.iloc[:, np.array([3,4,5,6,7,8,9,10])]  # covariates S together with the action A

The generalized random forest (GRF) approach has been implemented in the grf package for R and C++, and in the econml package for Python. Here we use econml for a simple illustration.

# import the package for Causal Random Forest
! pip install econml
# A demo code of Causal Random Forest
from econml.grf import CausalForest

# criterion='het' targets treatment-effect heterogeneity when splitting;
# honest=True reserves half of each subsample for estimating leaf values;
# inference=True enables confidence intervals for the predictions
est = CausalForest(criterion='het', n_estimators=400, min_samples_leaf=5, max_depth=None,
                   min_var_fraction_leaf=None, min_var_leaf_on_val=True,
                   min_impurity_decrease=0.0, max_samples=0.45, min_balancedness_tol=.45,
                   warm_start=False, inference=True, fit_intercept=True, subforest_size=4,
                   honest=True, verbose=0, n_jobs=-1, random_state=1235)


# fit the causal forest: X = user covariates, T = action (Drama), y = rating
est.fit(MovieLens_CEL.iloc[:,userinfo_index], MovieLens_CEL['Drama'], MovieLens_CEL['rating'])

# point estimates of the HTE at every sample point
HTE_GRF = est.predict(MovieLens_CEL.iloc[:,userinfo_index], interval=False, alpha=0.05)
HTE_GRF = HTE_GRF.flatten()
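Since the forest was fit with inference=True, pointwise confidence intervals are also available. The sketch below assumes the econml API in which predict with interval=True returns the point estimates together with the lower and upper bounds:

# pointwise 95% confidence intervals for the HTEs (alpha=0.05)
point, lb, ub = est.predict(MovieLens_CEL.iloc[:,userinfo_index], interval=True, alpha=0.05)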

Let’s focus on the estimated HTEs for three randomly chosen users:

print("Generalized Random Forest:  ",HTE_GRF[np.array([0,300,900])])
Generalized Random Forest:   [0.3588 0.3588 1.7786]
ATE_GRF = np.sum(HTE_GRF)/n  # average the HTEs to obtain the ATE
print("Choosing Drama instead of Sci-Fi is expected to improve the average rating by",round(ATE_GRF,4), "out of 5 points.")
Choosing Drama instead of Sci-Fi is expected to improve the average rating by 1.0468 out of 5 points.

Conclusion: Choosing Drama instead of Sci-Fi is expected to improve the average rating by 1.0468 out of 5 points.

References#

  1. Susan Athey, Julie Tibshirani, and Stefan Wager. Generalized random forests. The Annals of Statistics, 47(2):1148–1178, 2019.