Reproducibility Project for CS598 DL4H in Spring 2023: Using recurrent neural network models for early detection of heart failure onset
This repository is the official implementation of the Reproducibility Project for CS598 DL4H in Spring 2023: Using recurrent neural network models for early detection of heart failure onset.
In this Jupyter notebook, we attempt to recreate the findings from the paper 'Using recurrent neural network models for early detection of heart failure onset' by Edward Choi, Andy Schuetz, Walter F. Stewart, and Jimeng Sun. We create and train deep learning models to detect the early onset of heart failure using patient EHR data. The models are based on LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) architectures, which are types of recurrent neural networks (RNNs) that can capture dependencies in sequential data.
Requirements
We utilized an AWS Deep Learning AMI to conduct our reproducibility project on a p3dn.24xlarge instance (detailed specifications provided below).
AMI: Deep Learning AMI GPU TensorFlow 2.12.0 (Amazon Linux 2) 20230324
ami-0649417d1ede3c91a (64-bit (x86))
Virtualization: hvm | ENA enabled: true | Root device type: ebs
Instance Type: p3dn.24xlarge
GPUs (Tesla V100): 8
GPU Memory (GB): 256
vCPUs: 96
Memory (GB): 768
After cloning this repository, download the dataset, install the required packages, configure and start Jupyter Notebook, and then open the reproducibility-project-RNN-early-detection-heart-failure-onset.ipynb notebook.
# Download MIMIC-III dataset
wget -r -N -c -np --user <username> --ask-password -i physionet-downloads.txt
# Navigate to ~/physionet.org/files/mimiciii/1.4 and run:
gzip -d *
# Install PyHealth and required libraries
pip install pyhealth
pip install gputil psutil humanize memory_profiler scikit-learn-intelex imblearn
# Configure and run Jupyter
mkdir -p ~/ssl
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout ~/ssl/mykey.key -out ~/ssl/mycert.pem
jupyter notebook password
export TF_CPP_MIN_LOG_LEVEL=2
jupyter notebook --certfile ~/ssl/mycert.pem --keyfile ~/ssl/mykey.key
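Once the MIMIC-III files are extracted and PyHealth is installed, the notebook can load the raw tables. The snippet below is only a hedged sketch of that step, assuming PyHealth's MIMIC3Dataset loader and the extraction path used above; the table selection is illustrative:

```python
# Hedged sketch: load MIMIC-III with PyHealth (path and table choices are assumptions).
from pyhealth.datasets import MIMIC3Dataset

mimic3 = MIMIC3Dataset(
    root="~/physionet.org/files/mimiciii/1.4",  # adjust to your extraction path
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
)
mimic3.stat()  # print basic dataset statistics as a sanity check
```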
Training
To train an RNN on the MIMIC-III dataset, we provide a train_rnn() function whose rnn_type argument selects the type of RNN to use: GRU or LSTM. The function performs 10-fold cross-validation using StratifiedKFold from scikit-learn. A minimal training sketch is shown after the hyperparameter list below.
The following hyperparameters are configured:
- hidden_nodes is set to 256
- kernel_regularizer is set to L2 regularization with a strength of 0.001
- epochs is set to 100
- batch_size is set to 10
- optimizer is set to Adam with a learning rate of 0.01
- patience is set to 5
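The following is a minimal sketch of what train_rnn() could look like with these hyperparameters. The input arrays X (one-hot visit sequences) and y (heart-failure labels), as well as the exact layer layout, are assumptions rather than the notebook's exact code:

```python
# Hypothetical sketch of train_rnn(); X, y, and the layer layout are assumptions.
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold

def train_rnn(X, y, rnn_type="GRU"):
    """X: (n_patients, n_visits, n_codes) one-hot sequences; y: binary labels."""
    rnn_layer = tf.keras.layers.GRU if rnn_type == "GRU" else tf.keras.layers.LSTM
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    histories = []
    for train_idx, val_idx in skf.split(X, y):
        model = tf.keras.Sequential([
            rnn_layer(256, kernel_regularizer=tf.keras.regularizers.l2(0.001),
                      input_shape=X.shape[1:]),
            tf.keras.layers.Dense(1, activation="sigmoid"),
        ])
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
                      loss="binary_crossentropy",
                      metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])
        history = model.fit(
            X[train_idx], y[train_idx],
            validation_data=(X[val_idx], y[val_idx]),
            epochs=100, batch_size=10,
            callbacks=[tf.keras.callbacks.EarlyStopping(patience=5,
                                                        restore_best_weights=True)],
            verbose=0)
        histories.append(history)
    return histories
```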
In addition, we train Logistic Regression, SVM, MLP, KNN, and Decision Tree baselines by aggregating the one-hot encoded features, normalizing them, and flattening them into one-dimensional arrays. We then split the data into training and test sets with a stratified split so that the class distribution of the labels is preserved; see the sketch below.
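A hedged sketch of this baseline pipeline, assuming a 3-D one-hot array X that is summed over visits before scaling (variable names and the split ratio are illustrative, not the notebook's exact code):

```python
# Hypothetical baseline pipeline; X (3-D one-hot array) and y are assumptions.
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, roc_auc_score

# Aggregate one-hot visit vectors per patient, then normalize.
X_flat = StandardScaler().fit_transform(X.sum(axis=1))

# Stratified split keeps the class distribution similar in train and test sets.
X_tr, X_te, y_tr, y_te = train_test_split(
    X_flat, y, test_size=0.2, stratify=y, random_state=42)

baselines = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "SVM": SVC(probability=True),
    "MLP": MLPClassifier(),
    "KNN": KNeighborsClassifier(),
    "Decision Tree": DecisionTreeClassifier(),
}
for name, clf in baselines.items():
    clf.fit(X_tr, y_tr)
    proba = clf.predict_proba(X_te)[:, 1]
    print(name, accuracy_score(y_te, clf.predict(X_te)), roc_auc_score(y_te, proba))
```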
Evaluation & Results
We output the mean accuracy, AUC, and confusion matrix of the trained RNN models on the validation set of each fold. For the remaining models, we output accuracy and AUC. A sketch of the fold-level aggregation follows.
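A hedged sketch of how the per-fold metrics could be aggregated; fold_models and fold_val_sets are assumed containers filled during cross-validation, not names from the notebook:

```python
# Hypothetical fold-level evaluation; fold_models and fold_val_sets are assumptions.
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix

fold_acc, fold_auc, fold_cm = [], [], []
for model, (X_val, y_val) in zip(fold_models, fold_val_sets):
    proba = model.predict(X_val).ravel()
    preds = (proba >= 0.5).astype(int)
    fold_acc.append(accuracy_score(y_val, preds))
    fold_auc.append(roc_auc_score(y_val, proba))
    fold_cm.append(confusion_matrix(y_val, preds))

print("mean accuracy:", np.mean(fold_acc))
print("mean AUC:", np.mean(fold_auc))
print("mean confusion matrix:\n", np.mean(fold_cm, axis=0))
```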
Our models achieve the following performance:
Model | AUC (w/ duration, 12mo observation window, 6mo prediction window) | AUC (w/ duration, 18mo observation window, 0mo prediction window) | AUC (w/o duration, 12mo observation window, 6mo prediction window) | AUC (w/o duration, 18mo observation window, 0mo prediction window) |
---|---|---|---|---|
GRU | 0.690 | 0.727 | 0.688 | 0.790 |
LSTM | 0.618 | 0.737 | 0.710 | 0.644 |
Logistic Regression | 0.879 | 0.896 | 0.901 | 0.905 |
SVM | 0.909 | 0.933 | 0.926 | 0.929 |
MLP | 0.893 | 0.919 | 0.889 | 0.916 |
KNN | 0.836 | 0.843 | 0.849 | 0.845 |
Decision Tree | 0.864 | 0.894 | 0.879 | 0.895 |
Note: the confusion matrix values reported are mean values computed over the batch size.
Confusion Matrix for GRU model w/ duration, 12mo observation window and 6mo prediction window | Confusion Matrix for GRU model w/o duration, 12mo observation window and 6mo prediction window |
---|---|
![]() | ![]() |

Confusion Matrix for GRU model w/ duration, 18mo observation window and 0mo prediction window | Confusion Matrix for GRU model w/o duration, 18mo observation window and 0mo prediction window |
---|---|
![]() | ![]() |
Contributing
We welcome contributions to this project! Whether you want to report a bug, request a feature, or submit a pull request, we appreciate your involvement.