
Reproducibility Project for CS598 DL4H in Spring 2023: Using recurrent neural network models for early detection of heart failure onset

This repository is the official implementation of the Reproducibility Project for CS598 DL4H in Spring 2023: Using recurrent neural network models for early detection of heart failure onset.

In this Jupyter notebook, we attempt to recreate the findings from the paper 'Using recurrent neural network models for early detection of heart failure onset' by Edward Choi, Andy Schuetz, Walter F. Stewart, and Jimeng Sun. We create and train deep learning models to detect the early onset of heart failure from patient EHR data. The models are based on LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) architectures, types of recurrent neural networks (RNNs) that can capture dependencies in sequential data.
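
For orientation, the core model can be sketched in Keras roughly as follows. This is a minimal, illustrative snippet rather than the notebook's exact code: the zero-padding/masking scheme is an assumption, and the default hyperparameter values simply mirror those listed under Training below.

import tensorflow as tf

def build_rnn(rnn_type, n_visits, n_features,
              hidden_nodes=256, l2_strength=0.001, learning_rate=0.01):
    # Choose the recurrent layer: GRU or LSTM.
    rnn_layer = tf.keras.layers.GRU if rnn_type == "GRU" else tf.keras.layers.LSTM
    model = tf.keras.Sequential([
        # Mask all-zero padding visits so variable-length patient histories
        # are handled (padding scheme assumed here, not taken from the paper).
        tf.keras.layers.Masking(mask_value=0.0, input_shape=(n_visits, n_features)),
        rnn_layer(hidden_nodes,
                  kernel_regularizer=tf.keras.regularizers.l2(l2_strength)),
        # Sigmoid output for the binary heart-failure label.
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss="binary_crossentropy",
                  metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])
    return model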

Requirements

We utilized an AWS Deep Learning AMI to conduct our reproducibility project on a p3dn.24xlarge instance (detailed specifications provided below).

AMI: Deep Learning AMI GPU TensorFlow 2.12.0 (Amazon Linux 2) 20230324
ami-0649417d1ede3c91a (64-bit(x86))
Virtualization: hvm | ENA enabled: true | Root device type: ebs

Instance Type: p3dn.24xlarge
GPUs - Tesla V100: 8
GPU Memory (GB): 256
vCPUs: 96 | Memory (GB): 768

After cloning this repository, download the dataset, install the required packages, configure and start Jupyter Notebook, and open the reproducibility-project-RNN-early-detection-heart-failure-onset.ipynb notebook.

# Download MIMIC-III dataset
wget -r -N -c -np --user <username> --ask-password -i physionet-downloads.txt

# Decompress the downloaded MIMIC-III tables
cd ~/physionet.org/files/mimiciii/1.4
gzip -d *.gz

# Install PyHealth and required libraries
pip install pyhealth
pip install gputil psutil humanize memory_profiler scikit-learn-intelex imblearn

# Configure and run Jupyter
mkdir -p ~/ssl
openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout ~/ssl/mykey.key -out ~/ssl/mycert.pem
jupyter notebook password
export TF_CPP_MIN_LOG_LEVEL=2
jupyter notebook --certfile=~/ssl/mycert.pem --keyfile=~/ssl/mykey.key

Training

To train an RNN on the MIMIC-III dataset, we provide a train_rnn() function whose rnn_type argument specifies the type of RNN to use: GRU or LSTM. The function performs 10-fold cross-validation using the StratifiedKFold class from scikit-learn; a rough sketch of this training loop follows the hyperparameter list below.

The following hyperparameters are configured:

  • hidden_nodes is set to 256
  • kernel_regularizer is set to L2 regularization with a strength of 0.001
  • epochs is set to 100
  • batch_size is set to 10
  • optimizer is set to Adam with a learning rate of 0.01
  • patience is set to 5
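
A rough sketch of this cross-validated training loop is shown below, reusing the build_rnn sketch from the introduction. It is illustrative rather than the notebook's exact code: x (padded visit sequences) and y (heart-failure labels) are placeholder names, and the shuffle seed and early-stopping details are assumptions.

import numpy as np
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold

def train_rnn(x, y, rnn_type="GRU", n_splits=10):
    # Stratified 10-fold cross-validation (shuffle/seed are illustrative choices).
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
    fold_metrics = []
    for train_idx, val_idx in skf.split(x, y):
        model = build_rnn(rnn_type, n_visits=x.shape[1], n_features=x.shape[2])
        # Stop training when the validation loss stops improving (patience = 5).
        early_stop = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5,
                                                      restore_best_weights=True)
        model.fit(x[train_idx], y[train_idx],
                  validation_data=(x[val_idx], y[val_idx]),
                  epochs=100, batch_size=10,
                  callbacks=[early_stop], verbose=0)
        fold_metrics.append(model.evaluate(x[val_idx], y[val_idx], verbose=0))
    # Mean [loss, accuracy, AUC] across the validation folds.
    return np.mean(fold_metrics, axis=0)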

In addition, we train Logistic Regression, SVM, MLP, KNN, and Decision Tree baselines by aggregating the one-hot encoded features, normalizing them, and flattening them into one-dimensional arrays. We then split the data into training and test sets while preserving a similar class distribution in the labels.
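
A minimal sketch of this baseline pipeline is shown below, with toy arrays standing in for the real EHR tensors. The sum aggregation over visits, the use of StandardScaler for normalization, and the 80/20 split size are assumptions for illustration, not necessarily the notebook's exact choices.

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# Toy stand-ins: x_onehot is (patients, visits, codes), y is the binary label.
rng = np.random.default_rng(0)
x_onehot = rng.integers(0, 2, size=(200, 20, 100)).astype(float)
y = rng.integers(0, 2, size=200)

# Aggregate over the visit dimension into one vector per patient, then normalize.
x_flat = StandardScaler().fit_transform(x_onehot.sum(axis=1))

# Stratified split keeps the class distribution similar in train and test sets.
x_train, x_test, y_train, y_test = train_test_split(
    x_flat, y, test_size=0.2, stratify=y, random_state=0)

# The same arrays feed the SVM, MLP, KNN, and Decision Tree baselines;
# Logistic Regression is shown as one example.
clf = LogisticRegression(max_iter=1000).fit(x_train, y_train)
print("test accuracy:", clf.score(x_test, y_test))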

Evaluation & Results

We output the mean accuracy, AUC and confusion matrix of the trained RNN models on the validation set of each fold. For the remaining models, we output accuracy and AUC.
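
These per-fold metrics can be computed with scikit-learn along the following lines (an illustrative helper rather than the notebook's exact code; y_val and y_prob stand for a fold's validation labels and predicted probabilities):

import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix

def evaluate_fold(y_val, y_prob, threshold=0.5):
    # Threshold predicted probabilities to obtain hard labels.
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    return (accuracy_score(y_val, y_pred),
            roc_auc_score(y_val, y_prob),
            confusion_matrix(y_val, y_pred))

# Per-fold accuracies, AUCs, and confusion matrices are then averaged across folds.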

Our models achieve the following performance:

Model                 w/ duration,          w/ duration,          w/o duration,         w/o duration,
                      12mo observation /    18mo observation /    12mo observation /    18mo observation /
                      6mo prediction        0mo prediction        6mo prediction        0mo prediction
                      AUC                   AUC                   AUC                   AUC
GRU                   0.690                 0.727                 0.688                 0.790
LSTM                  0.618                 0.737                 0.710                 0.644
Logistic Regression   0.879                 0.896                 0.901                 0.905
SVM                   0.909                 0.933                 0.926                 0.929
MLP                   0.893                 0.919                 0.889                 0.916
KNN                   0.836                 0.843                 0.849                 0.845
Decision Tree         0.864                 0.894                 0.879                 0.895

Note. The results obtained from the confusion matrix are the mean values calculated over the batch size.

[Figures: Confusion matrices for the GRU model, with and without duration, for the 12mo observation / 6mo prediction window and for the 18mo observation / 0mo prediction window.]

Contributing

We welcome contributions to this project! Whether you want to report a bug, request a feature, or submit a pull request, we appreciate your involvement.