diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb
index c4d5a68f77f5bade6cc4b557ca8afa61bede4374..1cb485799937731841e3bbc3a375323dbb0ea1ab 100644
--- a/src/cs598_dlh_final_project.ipynb
+++ b/src/cs598_dlh_final_project.ipynb
@@ -1,13 +1,5 @@
 {
  "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "3225717c-458b-45cb-9e8b-89ecd03b9d07",
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "markdown",
    "id": "7ccba6af-1867-491d-b9d0-d99965738c64",
@@ -18,7 +10,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 13,
    "id": "5da368cd-3045-4ec0-86fb-2ce15b5d1b92",
    "metadata": {},
    "outputs": [],
@@ -26,14 +18,18 @@
     "# General includes.\n",
     "import os\n",
     "\n",
+    "# Typing includes.\n",
+    "from typing import Dict, List, Optional\n",
+    "\n",
     "# Numerical includes.\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "import torch\n",
     "from torch import nn\n",
     "\n",
-    "# Pyhealth includes.\n",
-    "from pyhealth.datasets import MIMIC3Dataset"
+    "# pyHealth includes.\n",
+    "from pyhealth.datasets import MIMIC3Dataset\n",
+    "from pyhealth.data import Patient, Visit, Event"
    ]
   },
   {
@@ -57,12 +53,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 22,
    "id": "ebef3528-0396-459a-8112-b272089f5d67",
    "metadata": {},
    "outputs": [],
    "source": [
-    "USE_GPU_ = False"
+    "USE_GPU_ = False\n",
+    "GPU_STRING_ = 'cuda'\n",
+    "DATA_DIR_ = '../data_input_path/mimic'"
    ]
   },
   {
@@ -70,7 +68,7 @@
    "id": "df0f8271-44b2-44a2-b7cc-cd7339a70e87",
    "metadata": {},
    "source": [
-    "## Preprocessing"
+    "# Preprocessing"
    ]
   },
   {
@@ -176,9 +174,9 @@
     "\n",
     "# If you have a GPU, put everything on cuda\n",
     "if USE_GPU_:\n",
-    "    tokens_tensor = tokens_tensor.to('cuda')\n",
-    "    segments_tensors = segments_tensors.to('cuda')\n",
-    "    model.to('cuda')\n",
+    "    tokens_tensor = tokens_tensor.to(GPU_STRING_)\n",
+    "    segments_tensors = segments_tensors.to(GPU_STRING_)\n",
+    "    model.to(GPU_STRING_)\n",
     "\n",
     "# Predict all tokens\n",
     "with torch.no_grad():\n",
@@ -227,16 +225,302 @@
    "id": "8f0d5fa9-226e-4612-b28f-ed24de16399a",
    "metadata": {},
    "source": [
-    "### Load "
+    "### Load MIMIC III Data"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 29,
+   "id": "b881e548-4d27-4725-8c47-b2612157929e",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "IndentationError",
+     "evalue": "expected an indented block (845818540.py, line 69)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;36m  File \u001b[0;32m\"/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_27199/845818540.py\"\u001b[0;36m, line \u001b[0;32m69\u001b[0m\n\u001b[0;31m    def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:\u001b[0m\n\u001b[0m      ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m expected an indented block\n"
+     ]
+    }
+   ],
+   "source": [
+    "class MIMIC3DatasetWrapper(MIMIC3Dataset):\n",
+    "    ''' Add extra tables to the MIMIC III dataset.\n",
+    "    \n",
+    "      Some of the tables we need like \"D_ICD_DIAGNOSES\", \"D_ITEMS\", \"D_ICD_PROCEDURES\"\n",
+    "      are not supported out of the box. \n",
+    "      \n",
+    "      This class defines parsing methods to extract text data from these extra tables.\n",
+    "      The text data is generally joined on the PATIENTID, HADMID, ITEMID to match the\n",
+    "      pyHealth Vists class representation.\n",
+    "    '''\n",
+    "    \n",
+    "    # Skip init and defer to base class.\n",
+    "    \n",
+    "    # Note the name has to match the table name exactly.\n",
+    "    # See https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/mimic3.py#L71.\n",
+    "    def parse_d_icd_diagnoses(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n",
+    "        # TODO(botelho3) fill this in to join the text descriptions to the visit.\n",
+    "        return patients\n",
+    "#                 \"\"\"Helper function which parses DIAGNOSES_ICD table.\n",
+    "#         Will be called in `self.parse_tables()`\n",
+    "#         Docs:\n",
+    "#             - DIAGNOSES_ICD: https://mimic.mit.edu/docs/iii/tables/diagnoses_icd/\n",
+    "#         Args:\n",
+    "#             patients: a dict of `Patient` objects indexed by patient_id.\n",
+    "#         Returns:\n",
+    "#             The updated patients dict.\n",
+    "#         Note:\n",
+    "#             MIMIC-III does not provide specific timestamps in DIAGNOSES_ICD\n",
+    "#                 table, so we set it to None.\n",
+    "#         \"\"\"\n",
+    "#         table = \"DIAGNOSES_ICD\"\n",
+    "#         # read table\n",
+    "#         df = pd.read_csv(\n",
+    "#             os.path.join(self.root, f\"{table}.csv\"),\n",
+    "#             dtype={\"SUBJECT_ID\": str, \"HADM_ID\": str, \"ICD9_CODE\": str},\n",
+    "#         )\n",
+    "#         # drop records of the other patients\n",
+    "#         df = df[df[\"SUBJECT_ID\"].isin(patients.keys())]\n",
+    "#         # drop rows with missing values\n",
+    "#         df = df.dropna(subset=[\"SUBJECT_ID\", \"HADM_ID\", \"ICD9_CODE\"])\n",
+    "#         # sort by sequence number (i.e., priority)\n",
+    "#         df = df.sort_values([\"SUBJECT_ID\", \"HADM_ID\", \"SEQ_NUM\"], ascending=True)\n",
+    "#         # group by patient and visit\n",
+    "#         group_df = df.groupby(\"SUBJECT_ID\")\n",
+    "\n",
+    "#         # parallel unit of diagnosis (per patient)\n",
+    "#         def diagnosis_unit(p_id, p_info):\n",
+    "#             events = []\n",
+    "#             for v_id, v_info in p_info.groupby(\"HADM_ID\"):\n",
+    "#                 for code in v_info[\"ICD9_CODE\"]:\n",
+    "#                     event = Event(\n",
+    "#                         code=code,\n",
+    "#                         table=table,\n",
+    "#                         vocabulary=\"ICD9CM\",\n",
+    "#                         visit_id=v_id,\n",
+    "#                         patient_id=p_id,\n",
+    "#                     )\n",
+    "#                     events.append(event)\n",
+    "#             return events\n",
+    "\n",
+    "#         # parallel apply\n",
+    "#         group_df = group_df.parallel_apply(\n",
+    "#             lambda x: diagnosis_unit(x.SUBJECT_ID.unique()[0], x)\n",
+    "#         )\n",
+    "\n",
+    "#         # summarize the results\n",
+    "#         patients = self._add_events_to_patient_dict(patients, group_df)\n",
+    "#         return patients\n",
+    "    \n",
+    "    def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n",
+    "        # TODO(botelho3) fill this in to join the text descriptions to the visit.\n",
+    "        return patients\n",
+    "    \n",
+    "    def parse_d_icd_procedures(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n",
+    "        # TODO(botelho3) fill this in to join the text descriptions to the visit.\n",
+    "        return patients\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
    "id": "427a42a6-441e-438c-96f4-b38ba82fd192",
    "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 540.68it/s]\n",
+      "Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 5934.54it/s]\n",
+      "Parsing PRESCRIPTIONS: 100%|██████████████████████████████████████████████████████████| 122/122 [00:01<00:00, 99.08it/s]\n",
+      "Parsing LABEVENTS: 100%|██████████████████████████████████████████████████████████████| 129/129 [00:06<00:00, 21.36it/s]\n",
+      "Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 214.50it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Statistics of base dataset (dev=False):\n",
+      "\t- Dataset: MIMIC3DatasetWrapper\n",
+      "\t- Number of patients: 100\n",
+      "\t- Number of visits: 129\n",
+      "\t- Number of visits per patient: 1.2900\n",
+      "\t- Number of events per visit in D_ITEMS: 0.0000\n",
+      "\t- Number of events per visit in D_ICD_PROCEDURES: 0.0000\n",
+      "\t- Number of events per visit in PROCEDURES_ICD: 3.9225\n",
+      "\t- Number of events per visit in PRESCRIPTIONS: 115.6667\n",
+      "\t- Number of events per visit in LABEVENTS: 479.1628\n",
+      "\n",
+      "\n",
+      "dataset.patients: patient_id -> <Patient>\n",
+      "\n",
+      "<Patient>\n",
+      "    - visits: visit_id -> <Visit> \n",
+      "    - other patient-level info\n",
+      "    \n",
+      "    <Visit>\n",
+      "        - event_list_dict: table_name -> List[Event]\n",
+      "        - other visit-level info\n",
+      "    \n",
+      "        <Event>\n",
+      "            - code: str\n",
+      "            - other event-level info\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "mimic3base = MIMIC3DatasetWrapper(\n",
+    "    #root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/\",\n",
+    "    root=os.path.join(os.getcwd(), DATA_DIR_),\n",
+    "    tables=[\"D_ITEMS\", \"D_ICD_PROCEDURES\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\", \"LABEVENTS\",], # \"D_ICD_DIAGNOSES\", \"DIAGNOSES_ICD\"\n",
+    "    # map all NDC codes to ATC 3-rd level codes in these tables\n",
+    "    # See https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.\n",
+    "    code_mapping={\"NDC\": (\"ATC\", {\"target_kwargs\": {\"level\": 3}})},\n",
+    ")\n",
+    "\n",
+    "mimic3base.stat()\n",
+    "mimic3base.info()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b02b8e28-23ac-4ac1-a9b6-ae546e28c905",
+   "metadata": {},
+   "source": [
+    "**Tasks**\n",
+    "\n",
+    "Declare tasks for 2 of the 5 prediction tasks specified in the paper. We will create dataloaders for each task that contain the ICD codes and the raw text for each (patient, visit)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "08692aad-db67-487b-9b26-dbde0ab6adce",
+   "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "# The original authors tackled 5 tasks\n",
+    "#   1. readmission\n",
+    "#   2. mortality\n",
+    "#   3. an ICU stay exceeding three days\n",
+    "#   4. an ICU stay exceeding seven days\n",
+    "#   5. diagnosis prediction\n",
+    "\n",
+    "def readmission_pred_task(patient):\n",
+    "    \"\"\"\n",
+    "    patient is a <pyhealth.data.Patient> object\n",
+    "    \"\"\"\n",
+    "    samples = []\n",
+    "\n",
+    "    # loop over all visits but the last one\n",
+    "    for i in range(len(patient) - 1):\n",
+    "\n",
+    "        # visit and next_visit are both <pyhealth.data.Visit> objects\n",
+    "        visit = patient[i]\n",
+    "        next_visit = patient[i + 1]\n",
+    "\n",
+    "        # step 1: define the mortality_label\n",
+    "        if next_visit.discharge_status not in [0, 1]:\n",
+    "            mortality_label = 0\n",
+    "        else:\n",
+    "            mortality_label = int(next_visit.discharge_status)\n",
+    "\n",
+    "        # step 2: get code-based feature information\n",
+    "        conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n",
+    "        procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n",
+    "        drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n",
+    "        labevents = visit.get_code_list(table=\"LABEVENTS\")\n",
+    "\n",
+    "        # step 3: exclusion criteria: visits without condition, procedure, or drug\n",
+    "        if len(conditions) * len(procedures) * len(drugs) == 0: continue\n",
+    "        \n",
+    "        # step 4: assemble the samples\n",
+    "        samples.append(\n",
+    "            {\n",
+    "                \"visit_id\": visit.visit_id,\n",
+    "                \"patient_id\": patient.patient_id,\n",
+    "                # the following keys can be the \"feature_keys\" or \"label_key\" for initializing downstream ML model\n",
+    "                \"conditions\": conditions,\n",
+    "                \"procedures\": procedures,\n",
+    "                \"drugs\": drugs,\n",
+    "                \"label\": mortality_label,\n",
+    "            }\n",
+    "        )\n",
+    "    return samples\n",
+    "\n",
+    "def mortality_pred_task(patient):\n",
+    "    \"\"\"\n",
+    "    patient is a <pyhealth.data.Patient> object\n",
+    "    \"\"\"\n",
+    "    samples = []\n",
+    "\n",
+    "    # loop over all visits but the last one\n",
+    "    for i in range(len(patient) - 1):\n",
+    "\n",
+    "        # visit and next_visit are both <pyhealth.data.Visit> objects\n",
+    "        visit = patient[i]\n",
+    "        next_visit = patient[i + 1]\n",
+    "\n",
+    "        # step 1: define the mortality_label\n",
+    "        if next_visit.discharge_status not in [0, 1]:\n",
+    "            mortality_label = 0\n",
+    "        else:\n",
+    "            mortality_label = int(next_visit.discharge_status)\n",
+    "\n",
+    "        # step 2: get code-based feature information\n",
+    "        conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n",
+    "        procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n",
+    "        drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n",
+    "\n",
+    "        # step 3: exclusion criteria: visits without condition, procedure, or drug\n",
+    "        if len(conditions) * len(procedures) * len(drugs) == 0: continue\n",
+    "        \n",
+    "        # step 4: assemble the samples\n",
+    "        samples.append(\n",
+    "            {\n",
+    "                \"visit_id\": visit.visit_id,\n",
+    "                \"patient_id\": patient.patient_id,\n",
+    "                # the following keys can be the \"feature_keys\" or \"label_key\" for initializing downstream ML model\n",
+    "                \"conditions\": conditions,\n",
+    "                \"procedures\": procedures,\n",
+    "                \"drugs\": drugs,\n",
+    "                \"label\": mortality_label,\n",
+    "            }\n",
+    "        )\n",
+    "    return samples\n",
+    "\n",
+    "\n",
+    "# mimic3sample = mimic3base.set_task(task_fn=drug_recommendation_mimic3_fn) # use default task\n",
+    "# train_ds, val_ds, test_ds = split_by_patient(mimic3sample, [0.8, 0.1, 0.1])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "96d047ff-6412-462c-9294-a5ba2d3bdba2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "re_dataset = dataset.set_task(readmission_pred_task)\n",
+    "re_dataset.stat()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "70586b66-a814-4898-892b-d1bdd339ce7b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mor_dataset = dataset.set_task(mortality_pred_task)\n",
+    "mor_dataset.stat()"
+   ]
   },
   {
    "cell_type": "code",
@@ -244,6 +528,27 @@
    "id": "1056b60a-5861-4e47-b694-891053cc8470",
    "metadata": {},
    "outputs": [],
+   "source": [
+    "# create dataloaders (torch.data.DataLoader)\n",
+    "train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)\n",
+    "val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)\n",
+    "test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0f9b59b9-4bdb-44c4-b155-0a2a2d0f8d07",
+   "metadata": {},
+   "source": [
+    "### Embed MIMIC III Data using BERT"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1f5e0b1b-d75c-42ee-8d98-879f14b04606",
+   "metadata": {},
+   "outputs": [],
    "source": []
   }
  ],