Commit 2d2a37d3 authored by aBotel

Add some additional text LUT code and DataLoader skeleton.

parent 25d4d101
%% Cell type:markdown id:7ccba6af-1867-491d-b9d0-d99965738c64 tags:
# Imports
%% Cell type:code id:5da368cd-3045-4ec0-86fb-2ce15b5d1b92 tags:
``` python
# General includes.
import os
# Typing includes.
from typing import Dict, List, Optional, Any
# Numerical includes.
import numpy as np
import pandas as pd
import torch
from torch import nn
# pyHealth includes.
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.data import Patient, Visit, Event
```
%% Cell type:code id:95e714b1-ca5d-4d7b-9ace-6b4372adfd9f tags:
``` python
# Model imports.
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM
```
%% Cell type:markdown id:cf83fa50-9dd0-467b-bd36-634795e4a09c tags:
# Globals
%% Cell type:code id:ebef3528-0396-459a-8112-b272089f5d67 tags:
``` python
USE_GPU_ = False
GPU_STRING_ = 'cuda'
DATA_DIR_ = '../data_input_path/mimic'
```
%% Cell type:markdown id:df0f8271-44b2-44a2-b7cc-cd7339a70e87 tags:
# Preprocessing
%% Cell type:code id:02624f53-6cee-4db0-8b15-e08b755a04f2 tags:
``` python
```
%% Cell type:markdown id:ebe2c2d9-c161-4b9d-8182-86384faea929 tags:
### Load BERT Model
%% Cell type:markdown id:39b0edec-1570-4640-9926-d908b456f957 tags:
See instructions here:
- https://pypi.org/project/pytorch-pretrained-bert/#examples
- https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial
- The model could also come from the pytorch-transformers hub: https://pytorch.org/hub/huggingface_pytorch-transformers/
%% Cell type:code id:f40665f4-db92-4567-b0f9-001bd35e9284 tags:
``` python
# Load the pre-trained model tokenizer (vocabulary).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# TODO(botelho3): `bert-base-uncased` is big. Load `bert-tiny` instead from the filesystem?
# Model available at https://huggingface.co/prajjwal1/bert-tiny.
# model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
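# Hedged sketch (an assumption, not executed here): `prajjwal1/bert-tiny` is a
# community checkpoint, so it would be fetched with the newer `transformers`
# package rather than `pytorch_transformers`:
# from transformers import AutoTokenizer, AutoModelForMaskedLM
# tiny_tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-tiny')
# tiny_model = AutoModelForMaskedLM.from_pretrained('prajjwal1/bert-tiny')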
# Tokenized input.
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
# Mask a token that we will try to predict back with `BertForMaskedLM`.
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
# Convert tokens to vocabulary indices.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define the sentence A and B indices associated with the 1st and 2nd sentences (see the BERT paper).
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# Convert inputs to PyTorch tensors.
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
```
%% Output
100%|█████████████████████████████████████████████████████████| 231508/231508 [00:00<00:00, 877100.29B/s]
%% Cell type:code id:a7539b72-a871-4d52-947b-075b5c4be2ae tags:
``` python
```
%% Cell type:code id:de132943-2c3e-405b-b2f8-5d348472672c tags:
``` python
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()
# If you have a GPU, put everything on cuda
if USE_GPU_:
    tokens_tensor = tokens_tensor.to(GPU_STRING_)
    segments_tensors = segments_tensors.to(GPU_STRING_)
    model.to(GPU_STRING_)
# Predict all tokens
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)
# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
```
%% Output %% Output
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
TypeError Traceback (most recent call last) TypeError Traceback (most recent call last)
/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_23031/2203343201.py in <module> /var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_23031/2203343201.py in <module>
14 14
15 # confirm we were able to predict 'henson' 15 # confirm we were able to predict 'henson'
---> 16 predicted_index = torch.argmax(predictions[0, masked_index]).item() ---> 16 predicted_index = torch.argmax(predictions[0, masked_index]).item()
17 predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0] 17 predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
18 assert predicted_token == 'henson' 18 assert predicted_token == 'henson'
TypeError: tuple indices must be integers or slices, not tuple TypeError: tuple indices must be integers or slices, not tuple
%% Cell type:code id:02efb18a-626d-42d4-b19d-ba4b8f752a00 tags:
``` python
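# Hedged fix for the TypeError above: in `pytorch_transformers` the model
# returns a tuple whose first element holds the prediction scores of shape
# (batch, seq_len, vocab), and `token_type_ids` is a keyword argument.
with torch.no_grad():
    prediction_scores = model(tokens_tensor, token_type_ids=segments_tensors)[0]
predicted_index = torch.argmax(prediction_scores[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'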
```
%% Cell type:code id:0e2d58de-bf9a-4b57-977f-6aa3ac75e84f tags:
``` python
# https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial
class BERTClassification(nn.Module):
    def __init__(self):
        super(BERTClassification, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.bert_drop = nn.Dropout(0.4)
        self.out = nn.Linear(768, 1)

    def forward(self, ids, mask, token_type_ids):
        _, pooledOut = self.bert(ids, attention_mask=mask,
                                 token_type_ids=token_type_ids)
        bertOut = self.bert_drop(pooledOut)
        output = self.out(bertOut)
        return output
```
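%% Cell type:code id:3a1b2c4d-5e6f-4a7b-8c9d-0e1f2a3b4c5d tags:
``` python
# Hedged usage sketch (not part of the original commit): run BERTClassification
# on a toy input. `cls_tokenizer`, `cls_model`, and the sentence are
# illustrative assumptions; the class uses the cased vocabulary.
cls_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
cls_model = BERTClassification()
cls_model.eval()
ids = torch.tensor([cls_tokenizer.convert_tokens_to_ids(
    cls_tokenizer.tokenize("[CLS] patient admitted with sepsis [SEP]"))])
mask = torch.ones_like(ids)             # No padding, so attend everywhere.
token_type_ids = torch.zeros_like(ids)  # Single-sentence input.
with torch.no_grad():
    logits = cls_model(ids, mask, token_type_ids)  # Shape (1, 1).
```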
%% Cell type:markdown id:8f0d5fa9-226e-4612-b28f-ed24de16399a tags:
### Load MIMIC III Data
%% Cell type:code id:b881e548-4d27-4725-8c47-b2612157929e tags:
``` python
class MIMIC3DatasetWrapper(MIMIC3Dataset):
    '''Add extra tables to the MIMIC III dataset.

    Some of the tables we need, like "D_ICD_DIAGNOSES", "D_ITEMS", and
    "D_ICD_PROCEDURES", are not supported out of the box.
    This class defines parsing methods to extract text data from these extra
    tables. The text data is generally joined on PATIENTID, HADMID, or ITEMID
    to match the pyHealth Visit class representation.
    '''

    # We need to add storage for text-based lookup tables here.
    def __init__(self, *args, **kwargs):
        self._valid_text_tables = ["D_ICD_DIAGNOSES", "D_ITEMS", "D_ICD_PROCEDURES", "D_LABITEMS"]
        self._text_descriptions = {x: {} for x in self._valid_text_tables}
        self._text_luts = {x: {} for x in self._valid_text_tables}
        super().__init__(*args, **kwargs)

    def get_all_tables(self) -> List[str]:
        return list(self._text_descriptions.keys())

    def get_text_dict(self, table_name: str) -> Dict[str, Any]:
        return self._text_descriptions.get(table_name)

    def set_text_lut(self, table_name: str, lut: Dict[Any, Any]) -> None:
        self._text_luts[table_name] = lut

    def get_text_lut(self, table_name: str) -> Dict[Any, Any]:
        return self._text_luts[table_name]

    # Note: the method name has to match the table name exactly.
    # See https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/mimic3.py#L71.
    def parse_d_icd_diagnoses(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses the D_ICD_DIAGNOSES table.

        Will be called in `self.parse_tables()`.
        Docs:
            - D_ICD_DIAGNOSES: https://mimic.mit.edu/docs/iii/tables/d_icd_diagnoses/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The updated patients dict.
        """
        table = "D_ICD_DIAGNOSES"
        print(f"Parsing {table}")
        assert table in self._valid_text_tables
        # Read the table.
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            usecols=["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"],
            dtype={"ICD9_CODE": str, "SHORT_TITLE": str, "LONG_TITLE": str}
        )
        # Drop rows with missing values.
        df = df.dropna(subset=["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"])
        # Sort by ICD9 code.
        df = df.sort_values(["ICD9_CODE"], ascending=True)
        # print(df.head())
        self._text_descriptions[table] = df.reset_index(drop=True).to_dict(orient='split')
        # We haven't altered the patients dict, just return it.
        return patients

    def parse_d_labitems(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses the D_LABITEMS table.

        Will be called in `self.parse_tables()`.
        Docs:
            - D_LABITEMS: https://mimic.mit.edu/docs/iii/tables/d_labitems/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The updated patients dict.
        """
        table = "D_LABITEMS"
        print(f"Parsing {table}")
        assert table in self._valid_text_tables
        # Read the table.
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            usecols=["ITEMID", "LABEL", "CATEGORY", "FLUID"],
            dtype={"ITEMID": str, "LABEL": str, "CATEGORY": str, "FLUID": str}
        )
        # Drop rows with missing values.
        df = df.dropna(subset=["ITEMID", "LABEL", "CATEGORY", "FLUID"])
        # Sort by item id.
        df = df.sort_values(["ITEMID"], ascending=True)
        self._text_descriptions[table] = df.reset_index(drop=True).to_dict(orient='split')
        # We haven't altered the patients dict, just return it.
        return patients

    def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        # TODO(botelho3) - Note this may not be totally usable because the
        # unique ITEMID key only links to these tables:
        #   - INPUTEVENTS_MV on ITEMID
        #   - OUTPUTEVENTS on ITEMID
        #   - PROCEDUREEVENTS_MV on ITEMID
        # Not to the tables we want, e.g.
        """Helper function which parses the D_ITEMS table.

        Will be called in `self.parse_tables()`.
        Docs:
            - D_ITEMS: https://mimic.mit.edu/docs/iii/tables/d_items/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The updated patients dict.
        """
        table = "D_ITEMS"
        print(f"Parsing {table}")
        assert table in self._valid_text_tables
        # Read the table.
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            usecols=["ITEMID", "LABEL", "CATEGORY"],
            dtype={"ITEMID": str, "LABEL": str, "CATEGORY": str}
        )
        # Drop rows with missing values.
        df = df.dropna(subset=["ITEMID", "LABEL", "CATEGORY"])
        # Sort by item id.
        df = df.sort_values(["ITEMID"], ascending=True)
        self._text_descriptions[table] = df.reset_index(drop=True).to_dict(orient='split')
        # We haven't altered the patients dict, just return it.
        return patients

    def parse_d_icd_procedures(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses the D_ICD_PROCEDURES table.

        Will be called in `self.parse_tables()`.
        Docs:
            - D_ICD_PROCEDURES: https://mimic.mit.edu/docs/iii/tables/d_icd_procedures/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The updated patients dict.
        """
        table = "D_ICD_PROCEDURES"
        print(f"Parsing {table}")
        assert table in self._valid_text_tables
        # Read the table.
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            usecols=["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"],
            dtype={"ICD9_CODE": str, "SHORT_TITLE": str, "LONG_TITLE": str}
        )
        # Drop rows with missing values.
        df = df.dropna(subset=["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"])
        # Sort by ICD9 code.
        df = df.sort_values(["ICD9_CODE"], ascending=True)
        # print(df.head())
        self._text_descriptions[table] = df.reset_index(drop=True).to_dict(orient='split')
        # We haven't altered the patients dict, just return it.
        return patients
```
%% Cell type:code id:427a42a6-441e-438c-96f4-b38ba82fd192 tags:
``` python
mimic3base = MIMIC3DatasetWrapper(
    # root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    root=os.path.join(os.getcwd(), DATA_DIR_),
    tables=["D_ICD_DIAGNOSES", "D_ICD_PROCEDURES", "D_ITEMS", "D_LABITEMS",
            "DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],  # "LABEVENTS"],
    # Map all NDC codes to ATC 3rd-level codes in these tables.
    # See https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.
    code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
    # Slow.
    refresh_cache=False,
)
mimic3base.stat()
mimic3base.info()
```
%% Output
Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 594.73it/s]
Parsing D_ICD_DIAGNOSES
Parsing D_ICD_PROCEDURES
Parsing D_ITEMS
Parsing D_LABITEMS
Parsing DIAGNOSES_ICD: 100%|████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 5116.36it/s]
Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 6220.39it/s]
Parsing PRESCRIPTIONS: 100%|█████████████████████████████████████████████████████████| 122/122 [00:00<00:00, 137.11it/s]
Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 253.20it/s]
Statistics of base dataset (dev=False):
    - Dataset: MIMIC3DatasetWrapper
    - Number of patients: 100
    - Number of visits: 129
    - Number of visits per patient: 1.2900
    - Number of events per visit in D_ICD_DIAGNOSES: 0.0000
    - Number of events per visit in D_ICD_PROCEDURES: 0.0000
    - Number of events per visit in D_ITEMS: 0.0000
    - Number of events per visit in D_LABITEMS: 0.0000
    - Number of events per visit in DIAGNOSES_ICD: 13.6512
    - Number of events per visit in PROCEDURES_ICD: 3.9225
    - Number of events per visit in PRESCRIPTIONS: 115.6667
dataset.patients: patient_id -> <Patient>
<Patient>
    - visits: visit_id -> <Visit>
    - other patient-level info
<Visit>
    - event_list_dict: table_name -> List[Event]
    - other visit-level info
<Event>
    - code: str
    - other event-level info
%% Cell type:code id:4cef4476-6bc8-4791-b246-4c329a80a2e4 tags:
``` python
table_names = mimic3base.get_all_tables()
print(table_names)
for t in table_names:
    d = mimic3base.get_text_dict(t)
    print(f"Table: {t}")
    print(d['data'][:5])
    print('\n\n')
for t in table_names:
    d = mimic3base.get_text_dict(t)
    d = d['data']
    lut = {record[0]: record[1:] for record in d}
    mimic3base.set_text_lut(t, lut)
```
%% Output
['D_ICD_DIAGNOSES', 'D_ITEMS', 'D_ICD_PROCEDURES', 'D_LABITEMS']
Table: D_ICD_DIAGNOSES
[['0010', 'Cholera d/t vib cholerae', 'Cholera due to vibrio cholerae'], ['0011', 'Cholera d/t vib el tor', 'Cholera due to vibrio cholerae el tor'], ['0019', 'Cholera NOS', 'Cholera, unspecified'], ['0020', 'Typhoid fever', 'Typhoid fever'], ['0021', 'Paratyphoid fever a', 'Paratyphoid fever A']]
Table: D_ITEMS
[['1126', 'Art.pH', 'ABG'], ['1127', 'WBC (4-11,000)', 'Hematology'], ['1520', 'ACT', 'Coags'], ['1521', 'Albumin', 'Chemistry'], ['1522', 'Calcium', 'Chemistry']]
Table: D_ICD_PROCEDURES
[['0001', 'Ther ult head & neck ves', 'Therapeutic ultrasound of vessels of head and neck'], ['0002', 'Ther ultrasound of heart', 'Therapeutic ultrasound of heart'], ['0003', 'Ther ult peripheral ves', 'Therapeutic ultrasound of peripheral vascular vessels'], ['0009', 'Other therapeutic ultsnd', 'Other therapeutic ultrasound'], ['0010', 'Implant chemothera agent', 'Implantation of chemotherapeutic agent']]
Table: D_LABITEMS
[['50800', 'SPECIMEN TYPE', 'BLOOD', 'BLOOD GAS'], ['50801', 'Alveolar-arterial Gradient', 'Blood', 'Blood Gas'], ['50802', 'Base Excess', 'Blood', 'Blood Gas'], ['50803', 'Calculated Bicarbonate, Whole Blood', 'Blood', 'Blood Gas'], ['50804', 'Calculated Total CO2', 'Blood', 'Blood Gas']]
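%% Cell type:code id:9f3e7c21-5a44-4b1d-a111-3c2b1d0e9a77 tags:
``` python
# Hedged sanity check (not in the original commit): each LUT maps the first
# column (ICD9_CODE / ITEMID) to the remaining text columns, so index 1 of a
# D_ICD_DIAGNOSES value is the long title, per the printed rows above.
d_diag_lut = mimic3base.get_text_lut('D_ICD_DIAGNOSES')
assert d_diag_lut['0010'] == ['Cholera d/t vib cholerae', 'Cholera due to vibrio cholerae']
print(d_diag_lut['0010'][1])
```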
%% Cell type:markdown id:b02b8e28-23ac-4ac1-a9b6-ae546e28c905 tags:
**Tasks**
Declare tasks for 2 of the 5 prediction tasks specified in the paper. We will create dataloaders for each task that contain the ICD codes and the raw text for each (patient, visit) pair.
%% Cell type:code id:08692aad-db67-487b-9b26-dbde0ab6adce tags:
``` python
# The original authors tackled 5 tasks:
# 1. readmission
# 2. mortality
# 3. an ICU stay exceeding three days
# 4. an ICU stay exceeding seven days
# 5. diagnosis prediction
def readmission_pred_task(patient):
    """
    patient is a <pyhealth.data.Patient> object
    """
    samples = []
    # Loop over all visits but the last one.
    for i in range(len(patient) - 1):
        # visit and next_visit are both <pyhealth.data.Visit> objects.
        # Note: there are no visit.attr_dict keys.
        visit = patient[i]
        next_visit = patient[i + 1]
        # Step 1: define the mortality_label.
        if next_visit.discharge_status not in [0, 1]:
            mortality_label = 0
        else:
            mortality_label = int(next_visit.discharge_status)
        # Step 2: get code-based feature information.
        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        # drugs = visit.get_code_list(table="PRESCRIPTIONS")
        # TODO(botelho3) - add this datasource back in once we have the full MIMIC-III dataset.
        # labevents = visit.get_code_list(table="LABEVENTS")
        # Step 3: exclusion criteria: visits without conditions or procedures.
        if len(conditions) * len(procedures) == 0: continue
        # Step 3.5: build text lists from the ICD codes.
        d_diag = mimic3base.get_text_lut("D_ICD_DIAGNOSES")
        d_proc = mimic3base.get_text_lut("D_ICD_PROCEDURES")
        # Index 0 is the short title, index 1 is the long title.
        conditions_text = [d_diag[cond][1] for cond in conditions if cond in d_diag]
        procedures_text = [d_proc[proc][1] for proc in procedures if proc in d_proc]
        # labevents_text =
        # Step 4: assemble the samples.
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                # The following keys can be the "feature_keys" or "label_key" for initializing a downstream ML model.
                "conditions": conditions,
                "procedures": procedures,
                "conditions_text": conditions_text,
                "procedures_text": procedures_text,
                # "drugs": drugs,
                # "labevents": labevents,
                # "labevents_text": labevents_text
                "label": mortality_label,
            }
        )
    return samples

def mortality_pred_task(patient):
    """
    patient is a <pyhealth.data.Patient> object
    """
    samples = []
    # Loop over all visits but the last one.
    for i in range(len(patient) - 1):
        # visit and next_visit are both <pyhealth.data.Visit> objects.
        visit = patient[i]
        next_visit = patient[i + 1]
        # Step 1: define the mortality_label.
        if next_visit.discharge_status not in [0, 1]:
            mortality_label = 0
        else:
            mortality_label = int(next_visit.discharge_status)
        # Step 2: get code-based feature information.
        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        drugs = visit.get_code_list(table="PRESCRIPTIONS")
        # Step 3: exclusion criteria: visits without condition, procedure, or drug.
        if len(conditions) * len(procedures) * len(drugs) == 0: continue
        # Step 4: assemble the samples.
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                # The following keys can be the "feature_keys" or "label_key" for initializing a downstream ML model.
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "label": mortality_label,
            }
        )
    return samples

# mimic3sample = mimic3base.set_task(task_fn=drug_recommendation_mimic3_fn)  # use default task
# train_ds, val_ds, test_ds = split_by_patient(mimic3sample, [0.8, 0.1, 0.1])
```
%% Cell type:code id:96d047ff-6412-462c-9294-a5ba2d3bdba2 tags:
``` python
readm_dataset = mimic3base.set_task(readmission_pred_task)
readm_dataset.stat()
readm_dataset.samples[0]
```
%% Output
Generating samples for readmission_pred_task: 44%|███████████████▍ | 44/100 [00:00<00:00, 386.68it/s]
visit num_events: 598
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 252
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 414
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 276
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 314
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 774
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 1723
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 635
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 856
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 288
visit tables: ['DIAGNOSES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 252
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 393
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 641
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 310
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 361
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 178
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 385
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
Generating samples for readmission_pred_task: 100%|██████████████████████████████████| 100/100 [00:00<00:00, 328.61it/s]
visit num_events: 908
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 498
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 1178
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 822
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 185
visit tables: ['DIAGNOSES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 1496
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 177
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 1036
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 130
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 369
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 268
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
visit num_events: 136
visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']
visit attrs
{}
Statistics of sample dataset:
- Dataset: MIMIC3DatasetWrapper
- Task: readmission_pred_task
- Number of samples: 27
- Number of patients: 13
- Number of visits: 27
- Number of visits per patient: 2.0769
- conditions:
- Number of conditions per sample: 1.0000
- Number of unique conditions: 17
- Distribution of conditions (Top-10): [('0389', 8), ('486', 3), ('03849', 2), ('5715', 1), ('41071', 1), ('5750', 1), ('51884', 1), ('20280', 1), ('53642', 1), ('51881', 1)]
- procedures:
- Number of procedures per sample: 3.4074
- Number of unique procedures: 42
- Distribution of procedures (Top-10): [('966', 15), ('3893', 14), ('9671', 5), ('9672', 4), ('9604', 4), ('4513', 3), ('3891', 3), ('9915', 3), ('9904', 3), ('3995', 2)]
- drugs:
- Number of drugs per sample: 40.7037
- Number of unique drugs: 112
- Distribution of drugs (Top-10): [('B05X', 27), ('B01A', 25), ('N02A', 25), ('A02B', 24), ('A06A', 24), ('N02B', 23), ('J01X', 21), ('A12C', 20), ('A12B', 20), ('V04C', 19)]
- label:
- Number of label per sample: 1.0000
- Number of unique label: 2
- Distribution of label (Top-10): [(0, 26), (1, 1)]
{'visit_id': '122098',
'patient_id': '10059',
'conditions': '5715',
'procedures': ['4513',
'3891',
'3893',
'9672',
'9915',
'3895',
'3995',
'5491'],
'drugs': ['A12C',
'B05X',
'V04C',
'A02B',
'J01F',
'N05C',
'N01A',
'A12A',
'H01C',
'B02B',
'V06D',
'A10A',
'J01M',
'C01C',
'B01A',
'B05B',
'C05B',
'J01X',
'A06A',
'B05A',
'H02A',
'N02A',
'A04A'],
'label': 0}
%% Cell type:code id:70586b66-a814-4898-892b-d1bdd339ce7b tags:
``` python
mor_dataset = mimic3base.set_task(mortality_pred_task)
mor_dataset.stat()
mor_dataset.samples[0]
```
%% Cell type:markdown id:66a0281a-f95b-4991-abd5-5824fc37c520 tags:
### DataLoaders and Collate
%% Cell type:code id:924c8e3c-0482-4688-b2d6-d83d490060a9 tags:
``` python
# def collate_fn(data):
#     """
#     Collate the list of samples into batches. For each patient, you need to pad the diagnosis
#     sequences to the same shape (max # visits, len(freq_codes)). The padding information
#     is stored in `mask`.
#     Arguments:
#         data: a list of samples fetched from `CustomDataset`
#     Outputs:
#         x: a tensor of shape (# patients, max # visits, len(freq_codes)) of type torch.float
#         masks: a tensor of shape (# patients, max # visits) of type torch.bool
#         y: a tensor of shape (# patients) of type torch.float
#     Note that you can obtain the list of diagnosis codes and the list of hf labels
#     using: `sequences, labels = zip(*data)`
#     """
#     sequences, labels = zip(*data)
#     y = torch.tensor(labels, dtype=torch.float)
#     num_patients = len(sequences)
#     num_visits = [len(patient) for patient in sequences]
#     max_num_visits = max(num_visits)
#     x = torch.zeros((num_patients, max_num_visits, len(freq_codes)), dtype=torch.float)
#     for i_patient, patient in enumerate(sequences):
#         for j_visit, visit in enumerate(patient):
#             for code in visit:
#                 # TODO: 1. check if the code is in freq_codes;
#                 #       2. obtain the code index using code2idx;
#                 #       3. set the corresponding element in x to 1.
#                 try:
#                     idx = code2idx[code]
#                     x[i_patient, j_visit, idx] = 1
#                 except KeyError:
#                     pass
#     masks = torch.sum(x, dim=-1) > 0
#     return x, masks, y
```
%% Cell type:code id:1056b60a-5861-4e47-b694-891053cc8470 tags:
``` python
# Create dataloaders (torch.utils.data.DataLoader).
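# Hedged sketch (an assumption, not in the original commit): `train_ds`,
# `val_ds`, `test_ds`, and `get_dataloader` are not defined above, so we
# assume pyHealth's helpers and the readmission sample dataset from earlier:
from pyhealth.datasets import split_by_patient, get_dataloader
train_ds, val_ds, test_ds = split_by_patient(readm_dataset, [0.8, 0.1, 0.1])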
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
# from torch.utils.data import DataLoader
# def load_data(train_dataset, val_dataset, collate_fn):
#     '''
#     TODO: Implement this function to return the data loaders for the train and validation datasets.
#     Set batch size to 32. Set `shuffle=True` only for the train dataloader.
#     Arguments:
#         train_dataset: train dataset of type `CustomDataset`
#         val_dataset: validation dataset of type `CustomDataset`
#         collate_fn: collate function
#     Outputs:
#         train_loader, val_loader: train and validation dataloaders
#     Note that you need to pass the collate function to the data loader via `collate_fn`.
#     '''
#     batch_size = 32
#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
#     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
#     return train_loader, val_loader
```
%% Cell type:code id:035fccee-d5da-428c-a54a-0678a2d19b5e tags:
``` python
# Verify DataLoader properties.
# from torch.utils.data import DataLoader
# loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
# loader_iter = iter(loader)
# x, masks, y = next(loader_iter)
# assert x.dtype == torch.float
# assert y.dtype == torch.float
# assert masks.dtype == torch.bool
# assert x.shape == (10, 3, 105)
# assert y.shape == (10,)
# assert masks.shape == (10, 3)
# assert x[0][0].sum() == 9
# assert masks[0].sum() == 2
```
%% Cell type:markdown id:0f9b59b9-4bdb-44c4-b155-0a2a2d0f8d07 tags:
### Embed MIMIC III Data using BERT
%% Cell type:code id:1f5e0b1b-d75c-42ee-8d98-879f14b04606 tags:
``` python
```