{ "cells": [ { "cell_type": "markdown", "id": "99b219d1-2f7b-4ca3-8a20-1adeec8fab31", "metadata": {}, "source": [ "# MIMIC III (3-Stages)\n", "\n", "In this notebook, we conducted analysis on the MIMIC III data with 3 stages. We first analyzed the mediation effect and then evaluate the policy of interest and calculated the optimal policy. As informed by the causal structure learning, here we consider Glucose and PaO2_FiO2 as confounders/states, IV_Input as the action, SOFA as the mediator. " ] }, { "cell_type": "code", "execution_count": 1, "id": "51a5e576-f6f2-47c0-a96d-3935d0ee6ca8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Glucose_1PaO2_FiO2_1IV_Input_1SOFA_1Glucose_2PaO2_FiO2_2IV_Input_2SOFA_2Glucose_3PaO2_FiO2_3IV_Input_3SOFA_3Died_within_48H
0116.833333172.00000019125.666667364.00000015132.200000439.310339101
1120.000000170.00000015154.000000163.71428617164.000000174.000000161
2123.200000266.00000018126.40000094.00000017129.60000091.388889180
3168.500000260.83333303227.000000277.77777814257.857143191.9354821100
4115.000000146.00000017115.000000208.66666718115.000000210.0000001101
\n", "
" ], "text/plain": [ " Glucose_1 PaO2_FiO2_1 IV_Input_1 SOFA_1 Glucose_2 PaO2_FiO2_2 \\\n", "0 116.833333 172.000000 1 9 125.666667 364.000000 \n", "1 120.000000 170.000000 1 5 154.000000 163.714286 \n", "2 123.200000 266.000000 1 8 126.400000 94.000000 \n", "3 168.500000 260.833333 0 3 227.000000 277.777778 \n", "4 115.000000 146.000000 1 7 115.000000 208.666667 \n", "\n", " IV_Input_2 SOFA_2 Glucose_3 PaO2_FiO2_3 IV_Input_3 SOFA_3 \\\n", "0 1 5 132.200000 439.310339 1 0 \n", "1 1 7 164.000000 174.000000 1 6 \n", "2 1 7 129.600000 91.388889 1 8 \n", "3 1 4 257.857143 191.935482 1 10 \n", "4 1 8 115.000000 210.000000 1 10 \n", "\n", " Died_within_48H \n", "0 1 \n", "1 1 \n", "2 0 \n", "3 0 \n", "4 1 " ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "import pickle\n", "import numpy as np\n", "import pandas as pd\n", "file = open('mimic3_MDTR_data_dict_3stage_V2.pickle', 'rb')\n", "mimic3_MDTR = pickle.load(file)\n", "state, action, mediator, reward = mimic3_MDTR.values()\n", "reward.iloc[np.where(reward['Died_within_48H']==-1)[0],-1]=0 # change the discrete action to binary\n", "MDTR_data = pd.read_csv('mimic3_MDTR_3stage_V2.csv')\n", "MDTR_data.iloc[np.where(MDTR_data['Died_within_48H']==-1)[0],-1]=0\n", "MDTR_data.head()" ] }, { "cell_type": "markdown", "id": "c6b03650-db06-4e91-ac8b-bcc870187f00", "metadata": { "tags": [] }, "source": [ "# CEL: 3-Stage Mediation Analysis" ] }, { "cell_type": "markdown", "id": "c6de86b4-574a-4f15-a346-0a0122464ae7", "metadata": {}, "source": [ "Under the 3-stage setting, we are interested in analyzing the treatment effect on the final outcome Died_within_48H observed at the end of the study by comparing the target treatment regime that provides IV input at all three stages and the control treatment regime that does not provide any treatment. Using the Q-learning based estimator proposed in [1], we examine the natural direct and indirect effects of the target treatment regime based on observational data. With the code in the following blocks, the estimated effect components are summarized in the following, with the bootstrapped standard error included in the parenthesis:\n", "\n", "| NDE | NIE | TE |\n", "|-------|------|-------|\n", "| -.213(.695) | .156(.647) | -.057(.284) |" ] }, { "cell_type": "code", "execution_count": 2, "id": "681ab834-086a-438b-becc-0e101a17bde8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0 0.155758\n", " dtype: float64,\n", " 0 -0.212838\n", " dtype: float64,\n", " 0 -0.05708\n", " dtype: float64)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from causaldm.learners.CEL.MA import Mediated_QLearning\n", "MediatedQLearn = Mediated_QLearning.Mediated_QLearning()\n", "N=len(state)\n", "regime_control = pd.DataFrame({'IV_Input_1':[0]*N,'IV_Input_2':[0]*N, 'IV_Input_3':[0]*N}).set_index(state.index)\n", "regime_target = pd.DataFrame({'IV_Input_1':[1]*N,'IV_Input_2':[1]*N, 'IV_Input_3':[1]*N}).set_index(state.index)\n", "MediatedQLearn.train(state, action, mediator, reward, T=3, dim_state = 2, dim_mediator = 1, \n", " regime_target = regime_target, regime_control = regime_control,bootstrap = False)\n", "MediatedQLearn.est_NDE_NIE()\n", "MediatedQLearn.NIE, MediatedQLearn.NDE, MediatedQLearn.TE" ] }, { "cell_type": "code", "execution_count": 3, "id": "6c438fe5-fb92-4b6e-ae5e-3742590feb55", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([0.6474]), array([0.6946]), array([0.2835]))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "MediatedQLearn.train(state, action, mediator, reward, T=3, dim_state = 2, dim_mediator = 1, \n", " regime_target = regime_target, regime_control = regime_control,bootstrap = True, n_bs = 500)\n", "MediatedQLearn._predict_value_boots()\n", "MediatedQLearn.NIE_se, MediatedQLearn.NDE_se, MediatedQLearn.TE_se" ] }, { "cell_type": "markdown", "id": "c2b3bb90-76b4-4005-b626-574ad2759383", "metadata": {}, "source": [ "## CPL: 3-Stage Policy Evaluation" ] }, { "cell_type": "code", "execution_count": 4, "id": "92cc85ce-d2b7-445b-9950-e2a87ca7d2ad", "metadata": {}, "outputs": [], "source": [ "from causaldm.learners.CPL13.disc import QLearning" ] }, { "cell_type": "markdown", "id": "903e7f10-ce3d-4124-9d8f-4e1ab719e4b0", "metadata": {}, "source": [ "As an example, we use the **Q-learning** algorithm to evaluate policies based on the observed data, with the linear regression models defined as the following:\n", "\\begin{align}\n", "Q_1(s,a_1,\\boldsymbol{\\beta}) = &\\beta_{00}+\\beta_{01}*\\textrm{Glucose}_1+\\beta_{02}*\\textrm{PaO2_FiO2}_1\\\\\n", " &I(a_1=1)*\\{\\beta_{10}+\\beta_{11}*\\textrm{Glucose}_1+\\beta_{12}*\\textrm{PaO2_FiO2}_1\\},\\\\\n", "Q_2(s,a_2,\\boldsymbol{\\mu}) = &\\mu_{00}+\\mu_{01}*\\textrm{Glucose}_1+\\mu_{02}*\\textrm{PaO2_FiO2}_1+\\mu_{03}*\\textrm{SOFA}_1+\\\\\n", " &\\mu_{04}*\\textrm{Glucose}_2+\\mu_{05}*\\textrm{PaO2_FiO2}_2+\\\\\n", " &I(a_2=1)*\\{\\mu_{10}+\\mu_{11}*\\textrm{Glucose}_2+\\mu_{12}*\\textrm{PaO2_FiO2}_2+\\mu_{13}*\\textrm{SOFA}_1\\},\\\\\n", "Q_3(s,a_3,\\boldsymbol{\\theta}) = &\\theta_{00}+\\theta_{01}*\\textrm{Glucose}_1+\\theta_{02}*\\textrm{PaO2_FiO2}_1+\\theta_{03}*\\textrm{SOFA}_1+\\\\\n", " &\\theta_{04}*\\textrm{Glucose}_2+\\theta_{05}*\\textrm{PaO2_FiO2}_2+\\theta_{06}*\\textrm{SOFA}_2+\\\\\n", " &\\theta_{07}*\\textrm{Glucose}_3+\\theta_{08}*\\textrm{PaO2_FiO2}_3+\\\\\n", " &I(a_2=1)*\\{\\theta_{10}+\\theta_{11}*\\textrm{Glucose}_3+\\theta_{12}*\\textrm{PaO2_FiO2}_3\\}\n", "\\end{align}\n", "\n", "Using the code below, we evaluated two target polices (regimes). The first one is a fixed treatement regime that applies no treatment at all stages (Policy1), with an estimated value of .8991. Another is a fixed treatment regime that applies treatment at all stages (Policy2), with an estimated value of .8246. Therefore, the treatment effect of Policy2 comparing to Policy1 is -.0745, implying that receiving IV input increase the mortality rate." ] }, { "cell_type": "code", "execution_count": 5, "id": "88152436-d5eb-4ffd-ac45-d40fe19bdcda", "metadata": {}, "outputs": [], "source": [ "MDTR_data.rename(columns = {'Died_within_48H':'R',\n", " 'Glucose_1':'S1_1', 'Glucose_2':'S1_2','Glucose_3':'S1_3',\n", " 'PaO2_FiO2_1':'S3_1', 'PaO2_FiO2_2':'S3_2','PaO2_FiO2_3':'S3_3',\n", " 'SOFA_1':'S4_1', 'SOFA_2':'S4_2', 'SOFA_3':'S4_3',\n", " 'IV_Input_1':'A1','IV_Input_2':'A2', 'IV_Input_3':'A3'}, inplace = True)\n", "R = MDTR_data['R'] #lower the better\n", "S = MDTR_data[['S1_1','S1_2','S1_3','S3_1','S3_2','S3_3','S4_1','S4_2','S4_3']]\n", "A = MDTR_data[['A1','A2', 'A3']]\n", "# specify the model you would like to use\n", "model_info = [{\"model\": \"R~S1_1+S3_1+A1+S1_1*A1+S3_1*A1\",\n", " 'action_space':{'A1':[0,1]}},\n", " {\"model\": \"R~S1_1+S3_1+S4_1+S1_2+S3_2+A2+S1_2*A2+S3_2*A2+S4_1*A2\",\n", " 'action_space':{'A2':[0,1]}},\n", " {\"model\": \"R~S1_1+S3_1+S4_1+S1_2+S3_2+S4_2+S1_3+S3_3+A3+S1_3*A3+S3_3*A3+S4_2*A3\",\n", " 'action_space':{'A3':[0,1]}}]" ] }, { "cell_type": "code", "execution_count": 6, "id": "566b8bb8-80b1-4c29-8d92-3f273484094d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8990981941216631" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluating the policy with no treatment\n", "N=len(S)\n", "regime = pd.DataFrame({'A1':[0]*N,\n", " 'A2':[0]*N,\n", " 'A3':[0]*N}).set_index(S.index)\n", "#evaluate the regime\n", "QLearn = QLearning.QLearning()\n", "QLearn.train(S, A, R, model_info, T=3, regime = regime, evaluate = True, mimic3_clip = True)\n", "QLearn.predict_value(S)" ] }, { "cell_type": "code", "execution_count": 7, "id": "10f6441a-273c-4022-802d-e8d2df2336f6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8246053645689122" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluating the policy that gives IV input at both stages\n", "N=len(S)\n", "regime = pd.DataFrame({'A1':[1]*N,\n", " 'A2':[1]*N,\n", " 'A3':[1]*N}).set_index(S.index)\n", "#evaluate the regime\n", "QLearn = QLearning.QLearning()\n", "QLearn.train(S, A, R, model_info, T=3, regime = regime, evaluate = True, mimic3_clip = True)\n", "QLearn.predict_value(S)" ] }, { "cell_type": "markdown", "id": "e54de121-0b8b-4412-9678-3941846423c0", "metadata": {}, "source": [ "## CPL: 3-Stage Policy Optimization" ] }, { "cell_type": "markdown", "id": "e49f86ca-4a31-470d-b3b2-837b4c7dbcbf", "metadata": {}, "source": [ "Further, to find an optimal policy maximizing the expected value, we use the **Q-learning** algorithm again to do policy optimization. Using the regression model we specified above and the code in the following block, the estimated optimal policy is summarized as the following regime.\n", "\n", "- At stage 1:\n", " 1. We would recommend $A=0$ (IV_Input = 0) if $.0001*\\textrm{Glucose}_1+.0012*\\textrm{PaO2_FiO2}_1>.0551$\n", " 2. Else, we would recommend $A=1$ (IV_Input = 1).\n", "- At stage 2:\n", " 1. We would recommend $A=0$ (IV_Input = 0) if $.0002*\\textrm{Glucose}_2-.00001*\\textrm{PaO2_FiO2}_2+.0070*\\textrm{SOFA}_1<.0721$\n", " 2. Else, we would recommend $A=1$ (IV_Input = 1).\n", "- At stage 3:\n", " 1. We would recommend $A=0$ (IV_Input = 0) if $-.0005*\\textrm{Glucose}_2+.0008*\\textrm{PaO2_FiO2}_2-.0114*\\textrm{SOFA}_2<.2068$\n", " 2. Else, we would recommend $A=1$ (IV_Input = 1).\n", " \n", "Appling the estimated optimal regime to individuals in the observed data, we summarize the regime pattern for each patients in the following table:\n", "\n", "| # patients | IV_Input 1 | IV_Input 2 | IV_Input 3 |\n", "|------------|------------|------------|------------|\n", "| 23 | 1 | 1 | 0 |\n", "| 10 | 1 | 0 | 0 |\n", "| 5 | 1 | 1 | 1 |\n", "| 4 | 0 | 0 | 0 |\n", "| 4 | 0 | 1 | 0 |\n", "| 4 | 0 | 1 | 1 |\n", "| 2 | 0 | 0 | 1 |\n", "| 2 | 1 | 0 | 1 |\n", "\n", "The estimated value of the estimated optimal policy is **.9637**." ] }, { "cell_type": "code", "execution_count": 8, "id": "6a5a82ac-ef51-4899-b626-02cbee491120", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "opt_d: A3 A2 A1\n", "0 1 1 23\n", " 0 1 10\n", "1 1 1 5\n", "0 0 0 4\n", " 1 0 4\n", "1 1 0 4\n", " 0 0 2\n", " 1 2\n", "dtype: int64\n", "opt value: 0.9637185597953717\n" ] } ], "source": [ "# initialize the learner\n", "QLearn = QLearning.QLearning()\n", "# train the policy\n", "QLearn.train(S, A, R, model_info, T=3, mimic3_clip = True)\n", "# get the summary of the fitted Q models using the following code\n", "#print(\"fitted model Q0:\",QLearn.fitted_model[0].summary())\n", "#print(\"fitted model Q1:\",QLearn.fitted_model[1].summary())\n", "#4. recommend action\n", "opt_d = QLearn.recommend_action(S).value_counts()\n", "#5. get the estimated value of the optimal regime\n", "V_hat = QLearn.predict_value(S)\n", "print(\"opt_d:\",opt_d)\n", "print(\"opt value:\",V_hat)" ] }, { "cell_type": "markdown", "id": "87dcd03a-13db-4924-8340-3df258236487", "metadata": {}, "source": [ "## Reference\n", "\n", "[1] Zheng, W., & van der Laan, M. (2017). Longitudinal mediation analysis with time-varying mediators and exposures, with application to survival outcomes. Journal of causal inference, 5(2)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 5 }