MIMIC III (Single-Stage)#
In this notebook, we analyze the MIMIC III data in a single-stage setting. We first estimate the mediation effect, then evaluate the policies of interest, and finally learn the optimal policy. As informed by the causal structure learning, we treat Glucose and PaO2_FiO2 as the confounders/states, IV_Input as the action, SOFA as the mediator, and Died_within_48H as the final outcome/reward.
import pandas as pd
import numpy as np
single_data = pd.read_csv('mimic3_single_stage.csv')
# dichotomize the action: IV Input >= 1 becomes 1, otherwise 0
single_data.loc[single_data['IV Input'] < 1, 'IV Input'] = 0
single_data.loc[single_data['IV Input'] >= 1, 'IV Input'] = 1
# recode the outcome: Died within 48H of -1 becomes 0
single_data.loc[single_data['Died within 48H'] == -1, 'Died within 48H'] = 0
single_data.head(6)
|   | icustayid | Glucose | PaO2 | PaO2_FiO2 | IV Input | SOFA | Died within 48H |
|---|---|---|---|---|---|---|---|
| 0 | 1006 | 152.000000 | 100.200000 | 1.0 | 2.800000 | 7.600000 | 0.0 |
| 1 | 1204 | 138.794872 | 127.782051 | 1.0 | 1.153846 | 6.153846 | 1.0 |
| 2 | 4132 | 129.364286 | 123.956461 | 1.0 | 3.000000 | 4.600000 | 0.0 |
| 3 | 4201 | 145.580087 | 118.083333 | 1.0 | 1.363636 | 5.818182 | 1.0 |
| 4 | 5170 | 174.525000 | 147.350198 | 1.0 | 2.437500 | 4.125000 | 1.0 |
| 5 | 6504 | 106.081169 | 88.836364 | 0.0 | 0.363636 | 5.090909 | 1.0 |
single_data.shape
(57, 7)
state = np.array(single_data[['Glucose','PaO2_FiO2']])
action = np.array(single_data[['IV Input']])
mediator = np.array(single_data[['SOFA']])
reward = np.array(single_data[['Died within 48H']])
single_dataset = {'state':state,'action':action,'mediator':mediator,'reward':reward}
CEL: Single-Stage Mediation Analysis#
Under the single-stage setting, we are interested in the treatment effect on the final outcome Died_within_48H, observed at the end of the study, comparing the target treatment regime that provides IV input to all patients with the control regime that provides no treatment. Using the direct estimator proposed in [1], the IPW estimator proposed in [2], and the robust estimator proposed in [3], we estimate the natural direct effect (NDE), natural indirect effect (NIE), and total effect (TE) of the target regime from the observational data (the three estimands are defined after the table). Running the code in the following blocks yields the estimates summarized below:
|                  | NDE    | NIE    | TE     |
|------------------|--------|--------|--------|
| Direct Estimator | -.2132 | .0030  | -.2103 |
| IPW              | -.2332 | 0      | -.2332 |
| Robust           | -.2274 | -.0163 | -.2438 |
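Here, in the notation of [1] (with action 1 denoting always giving IV input and 0 denoting no treatment), \(\textrm{NDE} = E[Y(1, M(0))] - E[Y(0, M(0))]\) captures the effect of the action on the outcome with the mediator held at its natural value under no treatment, \(\textrm{NIE} = E[Y(1, M(1))] - E[Y(1, M(0))]\) captures the effect transmitted through the mediator, and \(\textrm{TE} = \textrm{NDE} + \textrm{NIE} = E[Y(1, M(1))] - E[Y(0, M(0))]\), where \(Y(a, M(a'))\) denotes the potential outcome under action \(a\) with the mediator SOFA set to its natural value under action \(a'\).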
Specifically, compared to providing no treatment, always giving IV input has a negative impact on the survival rate, and this total effect is dominated by the direct effect of the action on the final outcome rather than the indirect effect through SOFA.
from causaldm.learners.CEL.MA.ME_Single import ME_Single
# Control Policy
def control_policy(state = None, dim_state = None, action = None, get_a = False):
    # Control regime: never provide IV input.
    if get_a:
        # return the (single) action the control policy would take
        action_value = np.array([0])
    else:
        state = np.copy(state).reshape(-1, dim_state)
        NT = state.shape[0]
        if action is None:
            # actions chosen by the control policy for each of the NT units
            action_value = np.array([0] * NT)
        else:
            # probability that the control policy selects the observed actions
            action = np.copy(action).flatten()
            if len(action) == 1 and NT > 1:
                action = action * np.ones(NT)
            action_value = 1 - action
    return action_value
def target_policy(state, dim_state = 1, action = None):
    # Target regime: always provide IV input (probability pa = 1 of action 1).
    state = np.copy(state).reshape((-1, dim_state))
    NT = state.shape[0]
    pa = 1 * np.ones(NT)
    if action is None:
        if NT == 1:
            # draw the action for a single unit
            pa = pa[0]
            prob_arr = np.array([1 - pa, pa])
            action_value = np.random.choice([0, 1], 1, p=prob_arr)
        else:
            raise ValueError('No random for matrix input')
    else:
        # probability that the target policy selects the observed actions
        action = np.copy(action).flatten()
        action_value = pa * action + (1 - pa) * (1 - action)
    return action_value
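As a quick illustrative check (toy input, not part of the original analysis), both policies are deterministic: given observed actions they return the probability (0 or 1) of those actions under the policy, and the control policy can also return its action directly via get_a:
toy_state = np.zeros((3, 2))              # three hypothetical patients, two state variables
toy_action = np.array([0.0, 1.0, 1.0])    # their observed binary actions
control_policy(toy_state, dim_state=2, action=toy_action)   # array([1., 0., 0.])
target_policy(toy_state, dim_state=2, action=toy_action)    # array([0., 1., 1.])
control_policy(get_a=True)                                   # array([0]): control always withholds IV input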
# hyper-parameter grid for the decision-tree nuisance learners
problearner_parameters = {"splitter": ["best", "random"], "max_depth": range(1, 50)}
Direct_est = ME_Single(single_dataset, r_model = 'OLS',
problearner_parameters = problearner_parameters,
truncate = 50,
target_policy=target_policy, control_policy = control_policy,
dim_state = 2, dim_mediator = 1,
MCMC = 50,
nature_decomp = True,
seed = 10,
method = 'Direct')
Direct_est.estimate_DE_ME()
Direct_est.est_DE, Direct_est.est_ME, Direct_est.est_TE,
IPW_est = ME_Single(single_dataset, r_model = 'OLS',
problearner_parameters = problearner_parameters,
truncate = 50,
target_policy=target_policy, control_policy = control_policy,
dim_state = 2, dim_mediator = 1,
MCMC = 50,
nature_decomp = True,
seed = 10,
method = 'IPW')
IPW_est.estimate_DE_ME()
IPW_est.est_DE, IPW_est.est_ME, IPW_est.est_TE,
(-0.23320671819000469, 0.0, -0.23320671819000469)
Robust_est = ME_Single(single_dataset, r_model = 'OLS',
problearner_parameters = problearner_parameters,
truncate = 50,
target_policy=target_policy, control_policy = control_policy,
dim_state = 2, dim_mediator = 1,
MCMC = 50,
nature_decomp = True,
seed = 10,
method = 'Robust')
Robust_est.estimate_DE_ME()
Robust_est.est_DE, Robust_est.est_ME, Robust_est.est_TE,
(-0.22743187935855116, -0.01633182578287201, -0.24376370514142318)
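By construction the total effect decomposes into the two components; a small sanity check (illustrative, not in the original notebook) on the robust estimates:
# the estimated TE should equal the sum of the estimated NDE and NIE, up to floating-point error
np.isclose(Robust_est.est_DE + Robust_est.est_ME, Robust_est.est_TE)   # True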
CPL: Single-Stage Policy Evaluation#
from causaldm.learners.CPL13.disc import QLearning
As an example, we use the Q-learning algorithm to evaluate policies based on the observed data, with the Q-function specified by the following linear regression model:
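\(Q(s, a; \boldsymbol{\beta}) = \beta_{00} + \beta_{01} S_1 + \beta_{02} S_2 + a\,(\beta_{10} + \beta_{11} S_1 + \beta_{12} S_2)\), where \(S_1\) is Glucose, \(S_2\) is PaO2_FiO2, and \(a \in \{0, 1\}\) is the binary IV Input action; this corresponds to the formula R~S1+S2+A+S1*A+S2*A passed through model_info below (the \(\beta\) labels are simply our notation for the regression coefficients).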
Using the code below, we evaluated two target policies (regimes). The first is the fixed regime that never applies treatment (Policy1), with an estimated value of .9999. The second is the fixed regime that always applies treatment (Policy2), with an estimated value of .7647. The treatment effect of Policy2 compared to Policy1 is therefore -.2353, implying that always providing IV input increases the mortality rate.
single_data.rename(columns = {'Died within 48H':'R', 'Glucose':'S1', 'PaO2_FiO2':'S2', 'IV Input':'A'}, inplace = True)
R = single_data['R'] # outcome/reward (higher is better; QLearning maximizes its expected value)
S = single_data[['S1','S2']]
A = single_data[['A']]
# specify the model you would like to use
model_info = [{"model": "R~S1+S2+A+S1*A+S2*A",
'action_space':{'A':[0,1]}}]
# Evaluating the policy with no treatment
N=len(S)
regime = pd.DataFrame({'A':[0]*N}).set_index(S.index)
#evaluate the regime
QLearn = QLearning.QLearning()
QLearn.train(S, A, R, model_info, T=1, regime = regime, evaluate = True, mimic3_clip = True)
QLearn.predict_value(S)
0.9999999999999976
# Evaluating the policy that always gives IV input
N=len(S)
regime = pd.DataFrame({'A':[1]*N}).set_index(S.index)
#evaluate the regime
QLearn = QLearning.QLearning()
QLearn.train(S, A, R, model_info, T=1, regime = regime, evaluate = True, mimic3_clip = True)
QLearn.predict_value(S)
0.7647336090193217
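Taking the difference of the two estimated values reproduces the treatment effect quoted above (a small illustrative computation reusing the printed values):
value_policy1 = 0.9999999999999976   # estimated value of the no-treatment regime
value_policy2 = 0.7647336090193217   # estimated value of the always-treat regime
value_policy2 - value_policy1        # approximately -0.2353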
CPL: Single-Stage Policy Optimization#
Further, to find a policy that maximizes the expected value, we apply the Q-learning algorithm again for policy optimization. Using the regression model specified above and the code in the block below, the estimated optimal policy reduces to the following regime.
We would recommend \(A=0\) (IV_Input = 0) if \(-.0003*\textrm{Glucose}+.0012*\textrm{PaO2_FiO2}<.5633\)
Else, we would recommend \(A=1\) (IV_Input = 1).
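This rule can be read off the interaction coefficients in the regression summary below: treatment is recommended only when the estimated contrast \(Q(s,1)-Q(s,0)\) is positive. A short illustrative computation (using the rounded coefficients, so borderline cases may differ slightly) that should recover the same split of patients reported in the table:
# apply the estimated rule with the rounded coefficients from the fitted Q model
contrast = -0.5633 - 0.0003 * S['S1'] + 0.0012 * S['S2']   # Q(s,1) - Q(s,0)
(contrast > 0).astype(int).value_counts()                   # recommended action counts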
Applying the estimated optimal regime to the individuals in the observed data, we summarize the recommended actions in the following table:
| # patients | IV_Input |
|------------|----------|
| 51         | 0        |
| 6          | 1        |

The estimated value of the estimated optimal policy is .9999.
# initialize the learner
QLearn = QLearning.QLearning()
# train the policy
QLearn.train(S, A, R, model_info, T=1, mimic3_clip = True)
# get the summary of the fitted Q models using the following code
print("fitted model Q0:",QLearn.fitted_model[0].summary())
# recommend actions for the observed states
opt_d = QLearn.recommend_action(S).value_counts()
# get the estimated value of the optimal regime
V_hat = QLearn.predict_value(S)
print("opt_d:",opt_d)
print("opt value:",V_hat)
fitted model Q0: OLS Regression Results
==============================================================================
Dep. Variable: R R-squared: 0.182
Model: OLS Adj. R-squared: 0.102
Method: Least Squares F-statistic: 2.276
Date: Sun, 29 Sep 2024 Prob (F-statistic): 0.0607
Time: 14:43:10 Log-Likelihood: -17.634
No. Observations: 57 AIC: 47.27
Df Residuals: 51 BIC: 59.53
Df Model: 5
Covariance Type: nonrobust
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 1.0000 0.318 3.141 0.003 0.361 1.639
S1 2.342e-17 0.002 1.42e-14 1.000 -0.003 0.003
S2 -3.442e-17 0.001 -6.81e-14 1.000 -0.001 0.001
A -0.5633 0.416 -1.354 0.182 -1.398 0.272
S1:A -0.0003 0.002 -0.136 0.892 -0.005 0.004
S2:A 0.0012 0.001 1.716 0.092 -0.000 0.003
==============================================================================
Omnibus: 16.164 Durbin-Watson: 1.728
Prob(Omnibus): 0.000 Jarque-Bera (JB): 18.591
Skew: -1.320 Prob(JB): 9.19e-05
Kurtosis: 3.927 Cond. No. 5.13e+03
==============================================================================
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
[2] The condition number is large, 5.13e+03. This might indicate that there are
strong multicollinearity or other numerical problems.
opt_d: A
0 51
1 6
dtype: int64
opt value: 0.9999999999999988
References#
[1] Robins, J. M., & Greenland, S. (1992). Identifiability and exchangeability for direct and indirect effects. Epidemiology, 3(2), 143–155.
[2] Hong, G. (2010). Ratio of mediator probability weighting for estimating natural direct and indirect effects. In Proceedings of the American Statistical Association, Biometrics Section (pp. 2401–2415).
[3] Tchetgen Tchetgen, E. J., & Shpitser, I. (2012). Semiparametric theory for causal mediation analysis: efficiency bounds, multiple robustness, and sensitivity analysis. Annals of Statistics, 40(3), 1816.