Commit be496e89 authored by aBotel

Notebook: add pyHealth dataset parsing (incomplete)

Add a custom wrapper around the MIMIC3Dataset pyHealth class that
adds functions to process the D_* text tables and extract the natural
language text we need for DescEmb.

Processing functions are not complete. TODO actually write these.

Force MIMIC3Dataset to load from local files using the global `DATA_DIR_`.

Added skeletons for 2 pyHealth tasks, readmission_pred_task and mortality_pred_task.
These are from the tutorial and they will be necessary to take the output of
MIMIC3DatasetWrapper and turn it into a test/train dataset.
TODO - incorporate the text data into the output of these tasks

TODO - hook up BERT model to some fake pasted strings of what we expect the
dataset to look like to verify it runs as we expect.
parent a4e89fe7
%% Cell type:code id:3225717c-458b-45cb-9e8b-89ecd03b9d07 tags:
``` python
```
%% Cell type:markdown id:7ccba6af-1867-491d-b9d0-d99965738c64 tags:
# Imports
%% Cell type:code id:5da368cd-3045-4ec0-86fb-2ce15b5d1b92 tags:
``` python
# General includes.
import os
# Typing includes.
from typing import Dict, List, Optional
# Numerical includes.
import numpy as np
import pandas as pd
import torch
from torch import nn
# pyHealth includes.
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.data import Patient, Visit, Event
```
%% Cell type:code id:95e714b1-ca5d-4d7b-9ace-6b4372adfd9f tags:
``` python
# Model imports
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM
```
%% Cell type:markdown id:cf83fa50-9dd0-467b-bd36-634795e4a09c tags:
# Globals
%% Cell type:code id:ebef3528-0396-459a-8112-b272089f5d67 tags:
``` python
USE_GPU_ = False
GPU_STRING_ = 'cuda'
DATA_DIR_ = '../data_input_path/mimic'
```
%% Cell type:markdown id:df0f8271-44b2-44a2-b7cc-cd7339a70e87 tags:
# Preprocessing
%% Cell type:code id:02624f53-6cee-4db0-8b15-e08b755a04f2 tags:
``` python
```
%% Cell type:markdown id:ebe2c2d9-c161-4b9d-8182-86384faea929 tags:
### Load BERT Model
%% Cell type:markdown id:39b0edec-1570-4640-9926-d908b456f957 tags:
See instructions here:
- https://pypi.org/project/pytorch-pretrained-bert/#examples
- https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial
- Could also get it from pytorch transformers library: https://pytorch.org/hub/huggingface_pytorch-transformers/
%% Cell type:code id:f40665f4-db92-4567-b0f9-001bd35e9284 tags:
``` python
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# TODO(botelho3): `bert-base-uncased` is big. Load `bert-tiny` from the filesystem instead?
# Model available at https://huggingface.co/prajjwal1/bert-tiny.
# model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
# Tokenized input
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
```
%% Output
100%|█████████████████████████████████████████████████████████| 231508/231508 [00:00<00:00, 877100.29B/s]
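%% Cell type:markdown tags:
The segment ids in the cell above are written out by hand. As a minimal sketch, they could instead be derived from the `[SEP]` positions in the token list. `segment_ids_from_tokens` is a hypothetical helper introduced here for illustration; it is not part of `pytorch_transformers`.
%% Cell type:code tags:
```python
from typing import List


def segment_ids_from_tokens(tokens: List[str], sep_token: str = "[SEP]") -> List[int]:
    """Assign a segment id to every token, starting a new segment after each separator.

    Hypothetical helper: the separator itself belongs to the segment it closes.
    """
    segment_ids = []
    current_segment = 0
    for token in tokens:
        segment_ids.append(current_segment)
        if token == sep_token:
            # Tokens after a separator belong to the next sentence.
            current_segment += 1
    return segment_ids


example_tokens = ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]',
                  'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
example_segment_ids = segment_ids_from_tokens(example_tokens)
```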
%% Cell type:code id:a7539b72-a871-4d52-947b-075b5c4be2ae tags:
``` python
```
%% Cell type:code id:de132943-2c3e-405b-b2f8-5d348472672c tags:
``` python
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()
# If you have a GPU, put everything on it.
if USE_GPU_:
    tokens_tensor = tokens_tensor.to(GPU_STRING_)
    segments_tensors = segments_tensors.to(GPU_STRING_)
    model.to(GPU_STRING_)
# Predict all tokens. pytorch_transformers models return a tuple whose
# first element holds the prediction scores, so unpack it before indexing.
with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    predictions = outputs[0]
# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
```
%% Output
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_23031/2203343201.py in <module>
14
15 # confirm we were able to predict 'henson'
---> 16 predicted_index = torch.argmax(predictions[0, masked_index]).item()
17 predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
18 assert predicted_token == 'henson'
TypeError: tuple indices must be integers or slices, not tuple
%% Cell type:code id:02efb18a-626d-42d4-b19d-ba4b8f752a00 tags:
``` python
```
%% Cell type:code id:0e2d58de-bf9a-4b57-977f-6aa3ac75e84f tags:
``` python
# https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial
class BERTClassification(nn.Module):
    def __init__(self):
        super(BERTClassification, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.bert_drop = nn.Dropout(0.4)
        self.out = nn.Linear(768, 1)

    def forward(self, ids, mask, token_type_ids):
        _, pooledOut = self.bert(ids, attention_mask=mask,
                                 token_type_ids=token_type_ids)
        bertOut = self.bert_drop(pooledOut)
        output = self.out(bertOut)
        return output
```
%% Cell type:markdown id:8f0d5fa9-226e-4612-b28f-ed24de16399a tags:
### Load MIMIC III Data
%% Cell type:code id:b881e548-4d27-4725-8c47-b2612157929e tags:
``` python
class MIMIC3DatasetWrapper(MIMIC3Dataset):
    ''' Add extra tables to the MIMIC III dataset.

    Some of the tables we need, like "D_ICD_DIAGNOSES", "D_ITEMS", "D_ICD_PROCEDURES",
    are not supported out of the box.
    This class defines parsing methods to extract text data from these extra tables.
    The text data is generally joined on SUBJECT_ID, HADM_ID, and ITEMID to match the
    pyHealth Visit class representation.
    '''
    # Skip init and defer to base class.

    # Note the name has to match the table name exactly.
    # See https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/mimic3.py#L71.
    def parse_d_icd_diagnoses(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        # TODO(botelho3) fill this in to join the text descriptions to the visit.
        return patients
        # """Helper function which parses DIAGNOSES_ICD table.
        # Will be called in `self.parse_tables()`
        # Docs:
        #     - DIAGNOSES_ICD: https://mimic.mit.edu/docs/iii/tables/diagnoses_icd/
        # Args:
        #     patients: a dict of `Patient` objects indexed by patient_id.
        # Returns:
        #     The updated patients dict.
        # Note:
        #     MIMIC-III does not provide specific timestamps in DIAGNOSES_ICD
        #     table, so we set it to None.
        # """
        # table = "DIAGNOSES_ICD"
        # # read table
        # df = pd.read_csv(
        #     os.path.join(self.root, f"{table}.csv"),
        #     dtype={"SUBJECT_ID": str, "HADM_ID": str, "ICD9_CODE": str},
        # )
        # # drop records of the other patients
        # df = df[df["SUBJECT_ID"].isin(patients.keys())]
        # # drop rows with missing values
        # df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "ICD9_CODE"])
        # # sort by sequence number (i.e., priority)
        # df = df.sort_values(["SUBJECT_ID", "HADM_ID", "SEQ_NUM"], ascending=True)
        # # group by patient and visit
        # group_df = df.groupby("SUBJECT_ID")
        # # parallel unit of diagnosis (per patient)
        # def diagnosis_unit(p_id, p_info):
        #     events = []
        #     for v_id, v_info in p_info.groupby("HADM_ID"):
        #         for code in v_info["ICD9_CODE"]:
        #             event = Event(
        #                 code=code,
        #                 table=table,
        #                 vocabulary="ICD9CM",
        #                 visit_id=v_id,
        #                 patient_id=p_id,
        #             )
        #             events.append(event)
        #     return events
        # # parallel apply
        # group_df = group_df.parallel_apply(
        #     lambda x: diagnosis_unit(x.SUBJECT_ID.unique()[0], x)
        # )
        # # summarize the results
        # patients = self._add_events_to_patient_dict(patients, group_df)
        # return patients

    def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        # TODO(botelho3) fill this in to join the text descriptions to the visit.
        return patients

    def parse_d_icd_procedures(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        # TODO(botelho3) fill this in to join the text descriptions to the visit.
        return patients
```
%% Output
File "/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_27199/845818540.py", line 69
def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
^
IndentationError: expected an indented block
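%% Cell type:markdown tags:
The `parse_d_*` methods are still stubs. As a rough sketch of the lookup they will need, the cell below builds a code-to-description dictionary from rows in the `D_ICD_DIAGNOSES` column layout (`ROW_ID`, `ICD9_CODE`, `SHORT_TITLE`, `LONG_TITLE`). The inline rows are illustrative and are not read from `DATA_DIR_`; `load_code_descriptions` and `describe_codes` are hypothetical helper names, and the fallback to the raw code is an assumption rather than a decided design.
%% Cell type:code tags:
```python
import csv
import io
from typing import Dict, IO, List

# Illustrative rows in the D_ICD_DIAGNOSES layout (not loaded from disk).
SAMPLE_D_ICD_DIAGNOSES = """ROW_ID,ICD9_CODE,SHORT_TITLE,LONG_TITLE
1,4019,Hypertension NOS,Unspecified essential hypertension
2,41401,Crnry athrscl natve vssl,Coronary atherosclerosis of native coronary artery
"""


def load_code_descriptions(csv_file: IO[str],
                           code_col: str = "ICD9_CODE",
                           text_col: str = "LONG_TITLE") -> Dict[str, str]:
    """Map each code in a D_* table to its text description."""
    reader = csv.DictReader(csv_file)
    return {row[code_col]: row[text_col] for row in reader}


def describe_codes(codes: List[str], descriptions: Dict[str, str]) -> List[str]:
    """Replace each code with its description, keeping the raw code if none is known."""
    return [descriptions.get(code, code) for code in codes]


descriptions = load_code_descriptions(io.StringIO(SAMPLE_D_ICD_DIAGNOSES))
described = describe_codes(["4019", "UNKNOWN"], descriptions)
```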
%% Cell type:code id:427a42a6-441e-438c-96f4-b38ba82fd192 tags:
``` python
mimic3base = MIMIC3DatasetWrapper(
    # root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    root=os.path.join(os.getcwd(), DATA_DIR_),
    tables=["D_ITEMS", "D_ICD_PROCEDURES", "PROCEDURES_ICD", "PRESCRIPTIONS", "LABEVENTS"],  # "D_ICD_DIAGNOSES", "DIAGNOSES_ICD"
    # Map all NDC codes to ATC 3rd-level codes in these tables.
    # See https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.
    code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
)
mimic3base.stat()
mimic3base.info()
```
%% Output
Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 540.68it/s]
Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 5934.54it/s]
Parsing PRESCRIPTIONS: 100%|██████████████████████████████████████████████████████████| 122/122 [00:01<00:00, 99.08it/s]
Parsing LABEVENTS: 100%|██████████████████████████████████████████████████████████████| 129/129 [00:06<00:00, 21.36it/s]
Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 214.50it/s]
Statistics of base dataset (dev=False):
- Dataset: MIMIC3DatasetWrapper
- Number of patients: 100
- Number of visits: 129
- Number of visits per patient: 1.2900
- Number of events per visit in D_ITEMS: 0.0000
- Number of events per visit in D_ICD_PROCEDURES: 0.0000
- Number of events per visit in PROCEDURES_ICD: 3.9225
- Number of events per visit in PRESCRIPTIONS: 115.6667
- Number of events per visit in LABEVENTS: 479.1628
dataset.patients: patient_id -> <Patient>
<Patient>
    - visits: visit_id -> <Visit>
    - other patient-level info
    <Visit>
        - event_list_dict: table_name -> List[Event]
        - other visit-level info
        <Event>
            - code: str
            - other event-level info
%% Cell type:markdown id:b02b8e28-23ac-4ac1-a9b6-ae546e28c905 tags:
**Tasks**
Declare 2 of the 5 prediction tasks specified in the paper. We will create a dataloader for each task that contains the ICD codes and the raw text for each (patient, visit) pair.
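One way the raw text could be attached to a task sample is sketched below, assuming samples are plain dictionaries with list-valued code keys such as `"conditions"`. `add_text_features`, the `<key>_text` naming, and the empty-string fallback are all hypothetical choices made for illustration; the description lookup is hard-coded here rather than read from the D_* tables.
%% Cell type:code tags:
```python
from typing import Any, Dict, List


def add_text_features(sample: Dict[str, Any],
                      descriptions: Dict[str, str],
                      code_keys: List[str]) -> Dict[str, Any]:
    """Return a copy of `sample` with a parallel `<key>_text` list per code key."""
    enriched = dict(sample)
    for key in code_keys:
        # Unknown codes get an empty description so the lists stay aligned.
        enriched[f"{key}_text"] = [descriptions.get(code, "") for code in sample.get(key, [])]
    return enriched


# Illustrative sample and lookup (hypothetical values).
sample = {"visit_id": "v1", "patient_id": "p1", "conditions": ["4019", "UNKNOWN"], "label": 0}
lookup = {"4019": "Unspecified essential hypertension"}
enriched_sample = add_text_features(sample, lookup, code_keys=["conditions"])
```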
%% Cell type:code id:08692aad-db67-487b-9b26-dbde0ab6adce tags:
``` python
# The original authors tackled 5 tasks
# 1. readmission
# 2. mortality
# 3. an ICU stay exceeding three days
# 4. an ICU stay exceeding seven days
# 5. diagnosis prediction
def readmission_pred_task(patient):
    """
    patient is a <pyhealth.data.Patient> object
    """
    samples = []
    # loop over all visits but the last one
    for i in range(len(patient) - 1):
        # visit and next_visit are both <pyhealth.data.Visit> objects
        visit = patient[i]
        next_visit = patient[i + 1]
        # step 1: define the label
        # TODO: this still reuses the tutorial's mortality label; replace it
        # with a readmission label.
        if next_visit.discharge_status not in [0, 1]:
            mortality_label = 0
        else:
            mortality_label = int(next_visit.discharge_status)
        # step 2: get code-based feature information
        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        drugs = visit.get_code_list(table="PRESCRIPTIONS")
        labevents = visit.get_code_list(table="LABEVENTS")
        # step 3: exclusion criteria: visits without condition, procedure, or drug
        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
        # step 4: assemble the samples
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                # the following keys can be the "feature_keys" or "label_key" for initializing downstream ML model
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "label": mortality_label,
            }
        )
    return samples


def mortality_pred_task(patient):
    """
    patient is a <pyhealth.data.Patient> object
    """
    samples = []
    # loop over all visits but the last one
    for i in range(len(patient) - 1):
        # visit and next_visit are both <pyhealth.data.Visit> objects
        visit = patient[i]
        next_visit = patient[i + 1]
        # step 1: define the mortality_label
        if next_visit.discharge_status not in [0, 1]:
            mortality_label = 0
        else:
            mortality_label = int(next_visit.discharge_status)
        # step 2: get code-based feature information
        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        drugs = visit.get_code_list(table="PRESCRIPTIONS")
        # step 3: exclusion criteria: visits without condition, procedure, or drug
        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
        # step 4: assemble the samples
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                # the following keys can be the "feature_keys" or "label_key" for initializing downstream ML model
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "label": mortality_label,
            }
        )
    return samples


# mimic3sample = mimic3base.set_task(task_fn=drug_recommendation_mimic3_fn) # use default task
# train_ds, val_ds, test_ds = split_by_patient(mimic3sample, [0.8, 0.1, 0.1])
```
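%% Cell type:markdown tags:
`readmission_pred_task` above computes its label from `discharge_status`, exactly as `mortality_pred_task` does. A minimal sketch of one possible readmission label follows, assuming readmission means the next visit starts within a fixed number of days of the current one. The function name, its parameters, and the 15-day window are assumptions for illustration, not values taken from the paper or this notebook.
%% Cell type:code tags:
```python
from datetime import datetime


def readmission_label(visit_time: datetime,
                      next_visit_time: datetime,
                      time_window_days: int = 15) -> int:
    """Return 1 if the next visit starts within `time_window_days` of this one, else 0."""
    days_between = (next_visit_time - visit_time).days
    return 1 if days_between < time_window_days else 0


# Illustrative timestamps (hypothetical values).
label_soon = readmission_label(datetime(2020, 1, 1), datetime(2020, 1, 10))
label_late = readmission_label(datetime(2020, 1, 1), datetime(2020, 2, 1))
```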
%% Cell type:code id:96d047ff-6412-462c-9294-a5ba2d3bdba2 tags:
``` python
re_dataset = mimic3base.set_task(readmission_pred_task)
re_dataset.stat()
```
%% Cell type:code id:70586b66-a814-4898-892b-d1bdd339ce7b tags:
``` python
mor_dataset = mimic3base.set_task(mortality_pred_task)
mor_dataset.stat()
```
%% Cell type:code id:1056b60a-5861-4e47-b694-891053cc8470 tags:
``` python
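# Split by patient, then create dataloaders (torch.utils.data.DataLoader).
from pyhealth.datasets import split_by_patient, get_dataloader

train_ds, val_ds, test_ds = split_by_patient(re_dataset, [0.8, 0.1, 0.1])
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)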
```
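%% Cell type:markdown tags:
The train/validation/test datasets should not share patients, otherwise visits from one patient can leak across splits. The sketch below illustrates that idea on plain sample dictionaries keyed by `"patient_id"`. It is a stand-alone illustration with a hypothetical function name and is not pyHealth's own splitting implementation.
%% Cell type:code tags:
```python
import random
from typing import Any, Dict, List, Sequence, Tuple

Sample = Dict[str, Any]


def split_samples_by_patient(samples: List[Sample],
                             ratios: Sequence[float],
                             seed: int = 0) -> Tuple[List[Sample], List[Sample], List[Sample]]:
    """Split samples into train/val/test so that each patient lands in exactly one split."""
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)
    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    train_ids = set(patient_ids[:n_train])
    val_ids = set(patient_ids[n_train:n_train + n_val])
    assigned_ids = train_ids | val_ids
    train = [s for s in samples if s["patient_id"] in train_ids]
    val = [s for s in samples if s["patient_id"] in val_ids]
    # Every remaining patient goes to the test split.
    test = [s for s in samples if s["patient_id"] not in assigned_ids]
    return train, val, test


# Illustrative data: 10 hypothetical patients with 2 visits each.
toy_samples = [{"patient_id": f"p{p}", "visit_id": f"p{p}-v{v}"}
               for p in range(10) for v in range(2)]
toy_train, toy_val, toy_test = split_samples_by_patient(toy_samples, [0.8, 0.1, 0.1])
```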
%% Cell type:markdown id:0f9b59b9-4bdb-44c4-b155-0a2a2d0f8d07 tags:
### Embed MIMIC III Data using BERT
%% Cell type:code id:1f5e0b1b-d75c-42ee-8d98-879f14b04606 tags:
``` python
```