diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb index c4d5a68f77f5bade6cc4b557ca8afa61bede4374..1cb485799937731841e3bbc3a375323dbb0ea1ab 100644 --- a/src/cs598_dlh_final_project.ipynb +++ b/src/cs598_dlh_final_project.ipynb @@ -1,13 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "3225717c-458b-45cb-9e8b-89ecd03b9d07", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "id": "7ccba6af-1867-491d-b9d0-d99965738c64", @@ -18,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "id": "5da368cd-3045-4ec0-86fb-2ce15b5d1b92", "metadata": {}, "outputs": [], @@ -26,14 +18,18 @@ "# General includes.\n", "import os\n", "\n", + "# Typing includes.\n", + "from typing import Dict, List, Optional\n", + "\n", "# Numerical includes.\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from torch import nn\n", "\n", - "# Pyhealth includes.\n", - "from pyhealth.datasets import MIMIC3Dataset" + "# pyHealth includes.\n", + "from pyhealth.datasets import MIMIC3Dataset\n", + "from pyhealth.data import Patient, Visit, Event" ] }, { @@ -57,12 +53,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 22, "id": "ebef3528-0396-459a-8112-b272089f5d67", "metadata": {}, "outputs": [], "source": [ - "USE_GPU_ = False" + "USE_GPU_ = False\n", + "GPU_STRING_ = 'cuda'\n", + "DATA_DIR_ = '../data_input_path/mimic'" ] }, { @@ -70,7 +68,7 @@ "id": "df0f8271-44b2-44a2-b7cc-cd7339a70e87", "metadata": {}, "source": [ - "## Preprocessing" + "# Preprocessing" ] }, { @@ -176,9 +174,9 @@ "\n", "# If you have a GPU, put everything on cuda\n", "if USE_GPU_:\n", - " tokens_tensor = tokens_tensor.to('cuda')\n", - " segments_tensors = segments_tensors.to('cuda')\n", - " model.to('cuda')\n", + " tokens_tensor = tokens_tensor.to(GPU_STRING_)\n", + " segments_tensors = segments_tensors.to(GPU_STRING_)\n", + " model.to(GPU_STRING_)\n", "\n", "# Predict all tokens\n", "with torch.no_grad():\n", @@ -227,16 +225,302 @@ "id": "8f0d5fa9-226e-4612-b28f-ed24de16399a", "metadata": {}, "source": [ - "### Load " + "### Load MIMIC III Data" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, + "id": "b881e548-4d27-4725-8c47-b2612157929e", + "metadata": {}, + "outputs": [ + { + "ename": "IndentationError", + "evalue": "expected an indented block (845818540.py, line 69)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m File \u001b[0;32m\"/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_27199/845818540.py\"\u001b[0;36m, line \u001b[0;32m69\u001b[0m\n\u001b[0;31m def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m expected an indented block\n" + ] + } + ], + "source": [ + "class MIMIC3DatasetWrapper(MIMIC3Dataset):\n", + " ''' Add extra tables to the MIMIC III dataset.\n", + " \n", + " Some of the tables we need like \"D_ICD_DIAGNOSES\", \"D_ITEMS\", \"D_ICD_PROCEDURES\"\n", + " are not supported out of the box. \n", + " \n", + " This class defines parsing methods to extract text data from these extra tables.\n", + " The text data is generally joined on the PATIENTID, HADMID, ITEMID to match the\n", + " pyHealth Vists class representation.\n", + " '''\n", + " \n", + " # Skip init and defer to base class.\n", + " \n", + " # Note the name has to match the table name exactly.\n", + " # See https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/mimic3.py#L71.\n", + " def parse_d_icd_diagnoses(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n", + " # TODO(botelho3) fill this in to join the text descriptions to the visit.\n", + " return patients\n", + "# \"\"\"Helper function which parses DIAGNOSES_ICD table.\n", + "# Will be called in `self.parse_tables()`\n", + "# Docs:\n", + "# - DIAGNOSES_ICD: https://mimic.mit.edu/docs/iii/tables/diagnoses_icd/\n", + "# Args:\n", + "# patients: a dict of `Patient` objects indexed by patient_id.\n", + "# Returns:\n", + "# The updated patients dict.\n", + "# Note:\n", + "# MIMIC-III does not provide specific timestamps in DIAGNOSES_ICD\n", + "# table, so we set it to None.\n", + "# \"\"\"\n", + "# table = \"DIAGNOSES_ICD\"\n", + "# # read table\n", + "# df = pd.read_csv(\n", + "# os.path.join(self.root, f\"{table}.csv\"),\n", + "# dtype={\"SUBJECT_ID\": str, \"HADM_ID\": str, \"ICD9_CODE\": str},\n", + "# )\n", + "# # drop records of the other patients\n", + "# df = df[df[\"SUBJECT_ID\"].isin(patients.keys())]\n", + "# # drop rows with missing values\n", + "# df = df.dropna(subset=[\"SUBJECT_ID\", \"HADM_ID\", \"ICD9_CODE\"])\n", + "# # sort by sequence number (i.e., priority)\n", + "# df = df.sort_values([\"SUBJECT_ID\", \"HADM_ID\", \"SEQ_NUM\"], ascending=True)\n", + "# # group by patient and visit\n", + "# group_df = df.groupby(\"SUBJECT_ID\")\n", + "\n", + "# # parallel unit of diagnosis (per patient)\n", + "# def diagnosis_unit(p_id, p_info):\n", + "# events = []\n", + "# for v_id, v_info in p_info.groupby(\"HADM_ID\"):\n", + "# for code in v_info[\"ICD9_CODE\"]:\n", + "# event = Event(\n", + "# code=code,\n", + "# table=table,\n", + "# vocabulary=\"ICD9CM\",\n", + "# visit_id=v_id,\n", + "# patient_id=p_id,\n", + "# )\n", + "# events.append(event)\n", + "# return events\n", + "\n", + "# # parallel apply\n", + "# group_df = group_df.parallel_apply(\n", + "# lambda x: diagnosis_unit(x.SUBJECT_ID.unique()[0], x)\n", + "# )\n", + "\n", + "# # summarize the results\n", + "# patients = self._add_events_to_patient_dict(patients, group_df)\n", + "# return patients\n", + " \n", + " def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n", + " # TODO(botelho3) fill this in to join the text descriptions to the visit.\n", + " return patients\n", + " \n", + " def parse_d_icd_procedures(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n", + " # TODO(botelho3) fill this in to join the text descriptions to the visit.\n", + " return patients\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 28, "id": "427a42a6-441e-438c-96f4-b38ba82fd192", "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 540.68it/s]\n", + "Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 5934.54it/s]\n", + "Parsing PRESCRIPTIONS: 100%|██████████████████████████████████████████████████████████| 122/122 [00:01<00:00, 99.08it/s]\n", + "Parsing LABEVENTS: 100%|██████████████████████████████████████████████████████████████| 129/129 [00:06<00:00, 21.36it/s]\n", + "Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 214.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Statistics of base dataset (dev=False):\n", + "\t- Dataset: MIMIC3DatasetWrapper\n", + "\t- Number of patients: 100\n", + "\t- Number of visits: 129\n", + "\t- Number of visits per patient: 1.2900\n", + "\t- Number of events per visit in D_ITEMS: 0.0000\n", + "\t- Number of events per visit in D_ICD_PROCEDURES: 0.0000\n", + "\t- Number of events per visit in PROCEDURES_ICD: 3.9225\n", + "\t- Number of events per visit in PRESCRIPTIONS: 115.6667\n", + "\t- Number of events per visit in LABEVENTS: 479.1628\n", + "\n", + "\n", + "dataset.patients: patient_id -> <Patient>\n", + "\n", + "<Patient>\n", + " - visits: visit_id -> <Visit> \n", + " - other patient-level info\n", + " \n", + " <Visit>\n", + " - event_list_dict: table_name -> List[Event]\n", + " - other visit-level info\n", + " \n", + " <Event>\n", + " - code: str\n", + " - other event-level info\n", + "\n" + ] + } + ], + "source": [ + "mimic3base = MIMIC3DatasetWrapper(\n", + " #root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/\",\n", + " root=os.path.join(os.getcwd(), DATA_DIR_),\n", + " tables=[\"D_ITEMS\", \"D_ICD_PROCEDURES\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\", \"LABEVENTS\",], # \"D_ICD_DIAGNOSES\", \"DIAGNOSES_ICD\"\n", + " # map all NDC codes to ATC 3-rd level codes in these tables\n", + " # See https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.\n", + " code_mapping={\"NDC\": (\"ATC\", {\"target_kwargs\": {\"level\": 3}})},\n", + ")\n", + "\n", + "mimic3base.stat()\n", + "mimic3base.info()\n" + ] + }, + { + "cell_type": "markdown", + "id": "b02b8e28-23ac-4ac1-a9b6-ae546e28c905", + "metadata": {}, + "source": [ + "**Tasks**\n", + "\n", + "Declare tasks for 2 of the 5 prediction tasks specified in the paper. We will create dataloaders for each task that contain the ICD codes and the raw text for each (patient, visit)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08692aad-db67-487b-9b26-dbde0ab6adce", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# The original authors tackled 5 tasks\n", + "# 1. readmission\n", + "# 2. mortality\n", + "# 3. an ICU stay exceeding three days\n", + "# 4. an ICU stay exceeding seven days\n", + "# 5. diagnosis prediction\n", + "\n", + "def readmission_pred_task(patient):\n", + " \"\"\"\n", + " patient is a <pyhealth.data.Patient> object\n", + " \"\"\"\n", + " samples = []\n", + "\n", + " # loop over all visits but the last one\n", + " for i in range(len(patient) - 1):\n", + "\n", + " # visit and next_visit are both <pyhealth.data.Visit> objects\n", + " visit = patient[i]\n", + " next_visit = patient[i + 1]\n", + "\n", + " # step 1: define the mortality_label\n", + " if next_visit.discharge_status not in [0, 1]:\n", + " mortality_label = 0\n", + " else:\n", + " mortality_label = int(next_visit.discharge_status)\n", + "\n", + " # step 2: get code-based feature information\n", + " conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n", + " procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n", + " drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n", + " labevents = visit.get_code_list(table=\"LABEVENTS\")\n", + "\n", + " # step 3: exclusion criteria: visits without condition, procedure, or drug\n", + " if len(conditions) * len(procedures) * len(drugs) == 0: continue\n", + " \n", + " # step 4: assemble the samples\n", + " samples.append(\n", + " {\n", + " \"visit_id\": visit.visit_id,\n", + " \"patient_id\": patient.patient_id,\n", + " # the following keys can be the \"feature_keys\" or \"label_key\" for initializing downstream ML model\n", + " \"conditions\": conditions,\n", + " \"procedures\": procedures,\n", + " \"drugs\": drugs,\n", + " \"label\": mortality_label,\n", + " }\n", + " )\n", + " return samples\n", + "\n", + "def mortality_pred_task(patient):\n", + " \"\"\"\n", + " patient is a <pyhealth.data.Patient> object\n", + " \"\"\"\n", + " samples = []\n", + "\n", + " # loop over all visits but the last one\n", + " for i in range(len(patient) - 1):\n", + "\n", + " # visit and next_visit are both <pyhealth.data.Visit> objects\n", + " visit = patient[i]\n", + " next_visit = patient[i + 1]\n", + "\n", + " # step 1: define the mortality_label\n", + " if next_visit.discharge_status not in [0, 1]:\n", + " mortality_label = 0\n", + " else:\n", + " mortality_label = int(next_visit.discharge_status)\n", + "\n", + " # step 2: get code-based feature information\n", + " conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n", + " procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n", + " drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n", + "\n", + " # step 3: exclusion criteria: visits without condition, procedure, or drug\n", + " if len(conditions) * len(procedures) * len(drugs) == 0: continue\n", + " \n", + " # step 4: assemble the samples\n", + " samples.append(\n", + " {\n", + " \"visit_id\": visit.visit_id,\n", + " \"patient_id\": patient.patient_id,\n", + " # the following keys can be the \"feature_keys\" or \"label_key\" for initializing downstream ML model\n", + " \"conditions\": conditions,\n", + " \"procedures\": procedures,\n", + " \"drugs\": drugs,\n", + " \"label\": mortality_label,\n", + " }\n", + " )\n", + " return samples\n", + "\n", + "\n", + "# mimic3sample = mimic3base.set_task(task_fn=drug_recommendation_mimic3_fn) # use default task\n", + "# train_ds, val_ds, test_ds = split_by_patient(mimic3sample, [0.8, 0.1, 0.1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96d047ff-6412-462c-9294-a5ba2d3bdba2", + "metadata": {}, + "outputs": [], + "source": [ + "re_dataset = dataset.set_task(readmission_pred_task)\n", + "re_dataset.stat()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70586b66-a814-4898-892b-d1bdd339ce7b", + "metadata": {}, + "outputs": [], + "source": [ + "mor_dataset = dataset.set_task(mortality_pred_task)\n", + "mor_dataset.stat()" + ] }, { "cell_type": "code", @@ -244,6 +528,27 @@ "id": "1056b60a-5861-4e47-b694-891053cc8470", "metadata": {}, "outputs": [], + "source": [ + "# create dataloaders (torch.data.DataLoader)\n", + "train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)\n", + "val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)\n", + "test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n" + ] + }, + { + "cell_type": "markdown", + "id": "0f9b59b9-4bdb-44c4-b155-0a2a2d0f8d07", + "metadata": {}, + "source": [ + "### Embed MIMIC III Data using BERT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f5e0b1b-d75c-42ee-8d98-879f14b04606", + "metadata": {}, + "outputs": [], "source": [] } ],