Learning representations for RL in Healthcare under a POMDP assumption, honoring the sequential nature of how the data is generated. This paper, accepted to the proceedings of the 2020 Machine Learning for Health Workshop at NeurIPS, empirically evaluates several recurrent autoencoding architectures to assess the quality of the internal representations each learns.
The motivation for this systematic analysis is that most prior work developing RL solutions for healthcare neglects to rigorously define state representations that respect the partial or sequential nature of the data generating process. We evaluate several recurrent autoencoding architectures, trained by predicting the subsequent physiological observation, by investigating the quality of the internal learned representations of patient state and by developing treatment policies from them.
If you use this code in your research, please cite the following publication (PMLR version):
@inproceedings{killian2020empirical,
title={An Empirical Study of Representation Learning for Reinforcement Learning in Healthcare},
author={Killian, Taylor W and Zhang, Haoran and Subramanian, Jayakumar and Fatemi, Mehdi and Ghassemi, Marzyeh},
booktitle={Machine Learning for Health},
pages={139--160},
year={2020},
organization={PMLR}
}
This paper can also be found on arXiv: https://arxiv.org/abs/2011.11235
Run the following commands to clone this repo and create the Conda environment:
git clone https://github.com/MLforHealth/rl_representations.git
cd rl_representations/
conda env create -f environment.yml
conda activate rl4h_rep
The data used to develop, run and evaluate our experiments is extracted from the MIMIC-III database, based on the Sepsis cohort used by Komorowski et al. (2018). We replicate and refine the code to extract this cohort at the following repository: https://github.com/microsoft/mimic_sepsis. For replication purposes, save both the z-normalized and "RAW" features.
This extracted cohort is for general-purpose use. The experiments in this paper required additional preprocessing, which we outline here.
- Run scripts/compute_acuity_scores.py to compute additional acuity scores with the raw patient features.
- Run scripts/split_sepsis_cohort.py to compose a training, validation and testing split as well as remove binary or demographic information from the temporal patient features. This script also organizes the patient data into convenient trajectory formats for easier use with sequential or recurrent models.
- Run scripts/create_buffers.py to create the replay buffers needed in the next step. Make sure to set the location where you want the resulting files to be saved in configs/config_behavCloning.yaml.
- Use Behavior Cloning on the data provided by the previous step to develop a baseline "expert" policy for use in training and evaluating RL policies from the learned patient representations.
  - Two options are possible: use either scripts/train_behavCloning_with_config_file.py if you want parameters to be loaded only from the config files, or scripts/train_behavCloning_with_command_line_args.py to use command line arguments. Note that in the second option, the parameters from the config files are overwritten by command line arguments, and if you don't specify command line arguments, the default values for these arguments from the code will be used, not the values from the config files. For the first option, the relevant configs are configs/config_behavCloning.yaml and configs/common.yaml. For the second option, an example of how this can be done is provided in slurm_scripts/slurm_build_BC.py --> slurm_scripts/slurm_bc_exp.
  - This will generate either behav_policy_file or behav_policy_file_wDemo, which is used internally in Steps 2 and 3 below.
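The precedence rule for the command line option is easy to misread, so here is a minimal sketch of the behavior as described above. It is not the repository's code: the parameter names (learning_rate, batch_size, storage_path) and the default values are hypothetical and chosen only for illustration.

```python
import argparse

# Hypothetical defaults hard-coded in the script (illustrative values only).
CODE_DEFAULTS = {"learning_rate": 1e-4, "batch_size": 64}


def resolve_from_command_line(argv, config):
    """Return the effective parameters when the command line variant is used.

    Every argument known to the parser overwrites the config value, even when
    the user did not pass it, because argparse fills in the code default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate", type=float,
                        default=CODE_DEFAULTS["learning_rate"])
    parser.add_argument("--batch_size", type=int,
                        default=CODE_DEFAULTS["batch_size"])
    args = parser.parse_args(argv)

    params = dict(config)      # start from the values in the config file
    params.update(vars(args))  # parsed arguments (or their defaults) win
    return params


if __name__ == "__main__":
    config = {"learning_rate": 0.01, "batch_size": 32, "storage_path": "/tmp/bc"}
    print(resolve_from_command_line(["--batch_size", "128"], config))
```

In this sketch, passing only --batch_size leaves learning_rate at the code default rather than the config value, which is the pitfall the note above warns about. Passing every argument explicitly avoids the surprise.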
- In configs/common.yaml, update the paths at the bottom of the file.
- Select a model: one of [AE, AIS, CDE, DDM, DST, ODERNN, RNN].
- Run python slurm_build_{model}.py to output a text file containing an array of launch arguments to scripts/train_model.py that were used to generate the results from the paper.
- If you are using a Slurm cluster, you can run sbatch slurm_{model}_exp after altering the Slurm header to run a Slurm task array where each job corresponds to training a single model with one line of arguments from the generated file.
- Otherwise, models can also be trained on an ad-hoc basis by manually calling python scripts/train_model.py with the arguments shown in the generated text file.
- The configuration files in config_sepsis_{model}.yaml contain the hyperparameters used in the paper. These can be changed if desired.
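For users without a Slurm cluster, the ad-hoc route above can be scripted. The following is a minimal sketch, assuming the generated text file holds one set of launch arguments per line. The helper names (build_commands, run_all) and the file path passed to run_all are hypothetical and not part of the repository.

```python
import shlex
import subprocess
import sys


def build_commands(arg_lines, script="scripts/train_model.py"):
    """Turn each non-empty line of launch arguments into a command list."""
    commands = []
    for line in arg_lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines in the generated file
        commands.append([sys.executable, script] + shlex.split(line))
    return commands


def run_all(args_file, dry_run=True):
    """Run (or, with dry_run=True, only print) one training job per line."""
    with open(args_file) as f:
        commands = build_commands(f)
    for cmd in commands:
        if dry_run:
            print(" ".join(shlex.quote(part) for part in cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first failing job
    return commands
```

Calling run_all with the path of the generated file and dry_run=True prints the commands without training anything, which is a cheap way to check the arguments before committing to long sequential runs.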
RL policies are learned using the discretized form of Batch Constrained Q-learning from Fujimoto et al. (2019). These policies are learned as the final part of the previous Step 2 inside of scripts/train_model.py. The policies are evaluated intermittently using a form of Weighted Importance Sampling (procedure found in scripts/dBCQ_utils.py, line 284).
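As a rough illustration of the evaluation idea, here is a minimal sketch of the textbook per-trajectory Weighted Importance Sampling estimator in plain Python. It is an assumption-laden stand-in: the procedure in scripts/dBCQ_utils.py may differ in details such as how the evaluation policy is softened or how weights are clipped, and the input layout used here is hypothetical.

```python
def weighted_importance_sampling(trajectories, gamma=1.0):
    """Estimate the value of an evaluation policy from logged trajectories.

    trajectories: list of trajectories, each a list of (pi_e, pi_b, reward)
    steps, where pi_e and pi_b are the probabilities of the logged action
    under the evaluation and behavior policies (pi_b is assumed to be > 0).
    """
    weighted_returns = 0.0
    total_weight = 0.0
    for steps in trajectories:
        rho = 1.0       # cumulative importance ratio for this trajectory
        ret = 0.0       # discounted return of this trajectory
        discount = 1.0
        for pi_e, pi_b, reward in steps:
            rho *= pi_e / pi_b
            ret += discount * reward
            discount *= gamma
        weighted_returns += rho * ret
        total_weight += rho
    if total_weight == 0.0:
        return 0.0  # no trajectory has support under the evaluation policy
    # Normalizing by the sum of weights (not the trajectory count) is what
    # distinguishes WIS from ordinary importance sampling.
    return weighted_returns / total_weight
```

When the evaluation and behavior policies agree on every logged action, all weights equal 1 and the estimate reduces to the average observed return.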
We compare various evaluations of the learned policies using the IPython notebook: notebooks/Evaluate\ dBCQ\ policies.ipynb.
We aggregate the results of all model training and the analysis of the learned representations in the notebook notebooks/Aggregate\ Next\ Step\ Results.ipynb.
See requirements.txt
- Taylor W. Killian @twkillian
- Haoran Zhang @hzhang0
- Jayakumar Subramanian
- Mehdi Fatemi
- Marzyeh Ghassemi
The source code and documentation are licensed under the terms of the MIT License.
Please don't hesitate to log an issue and we'll be as prompt as possible when responding.