From 2d2a37d30953af6a2053b3e26b4f2641c6e9836b Mon Sep 17 00:00:00 2001
From: aBotel <admbotel@uwaterloo.ca>
Date: Sat, 15 Apr 2023 02:24:54 -0700
Subject: [PATCH] Add some additional text LUT code and DataLoader skeleton.

---
 src/cs598_dlh_final_project.ipynb | 410 ++++++++++++++++++++++++++++--
 1 file changed, 383 insertions(+), 27 deletions(-)

diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb
index 05528cc..e3a3d55 100644
--- a/src/cs598_dlh_final_project.ipynb
+++ b/src/cs598_dlh_final_project.ipynb
@@ -230,7 +230,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 98,
+   "execution_count": 119,
    "id": "b881e548-4d27-4725-8c47-b2612157929e",
    "metadata": {},
    "outputs": [],
@@ -250,6 +250,7 @@
     "    def __init__(self, *args, **kwargs):\n",
     "        self._valid_text_tables = [\"D_ICD_DIAGNOSES\", \"D_ITEMS\", \"D_ICD_PROCEDURES\", \"D_LABITEMS\"]\n",
     "        self._text_descriptions = {x: {} for x in self._valid_text_tables}\n",
+    "        self._text_luts = {x: {} for x in self._valid_text_tables}\n",
     "        super().__init__(*args, **kwargs)\n",
     "    \n",
     "    def get_all_tables(self) -> List[str]: \n",
@@ -258,6 +259,12 @@
     "    def get_text_dict(self, table_name: str) -> Dict[str, Dict[Any, Any]]:\n",
     "        return self._text_descriptions.get(table_name)\n",
     "    \n",
+    "    def set_text_lut(self, table_name: str, lut: Dict[Any, Any]) -> None:\n",
+    "        self._text_luts[table_name] = lut\n",
+    "    \n",
+    "    def get_text_lut(self, table_name: str) -> Dict[Any, Any]:\n",
+    "        return self._text_luts[table_name]\n",
+    "    \n",
     "    # Note the name has to match the table name exactly.\n",
     "    # See https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/mimic3.py#L71.\n",
     "    def parse_d_icd_diagnoses(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n",
@@ -408,7 +415,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 96,
+   "execution_count": 124,
    "id": "427a42a6-441e-438c-96f4-b38ba82fd192",
    "metadata": {},
    "outputs": [
@@ -416,7 +423,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 616.21it/s]\n"
+      "Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 594.73it/s]\n"
      ]
     },
     {
@@ -433,11 +440,10 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Parsing DIAGNOSES_ICD: 100%|████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 5045.37it/s]\n",
-      "Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 6525.72it/s]\n",
-      "Parsing PRESCRIPTIONS: 100%|█████████████████████████████████████████████████████████| 122/122 [00:00<00:00, 136.63it/s]\n",
-      "Parsing LABEVENTS: 100%|██████████████████████████████████████████████████████████████| 129/129 [00:05<00:00, 25.78it/s]\n",
-      "Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 244.02it/s]\n"
+      "Parsing DIAGNOSES_ICD: 100%|████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 5116.36it/s]\n",
+      "Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 6220.39it/s]\n",
+      "Parsing PRESCRIPTIONS: 100%|█████████████████████████████████████████████████████████| 122/122 [00:00<00:00, 137.11it/s]\n",
+      "Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 253.20it/s]\n"
      ]
     },
     {
@@ -457,7 +463,6 @@
       "\t- Number of events per visit in DIAGNOSES_ICD: 13.6512\n",
       "\t- Number of events per visit in PROCEDURES_ICD: 3.9225\n",
       "\t- Number of events per visit in PRESCRIPTIONS: 115.6667\n",
-      "\t- Number of events per visit in LABEVENTS: 479.1628\n",
       "\n",
       "\n",
       "dataset.patients: patient_id -> <Patient>\n",
@@ -479,15 +484,15 @@
    ],
    "source": [
     "mimic3base = MIMIC3DatasetWrapper(\n",
-    "    #root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/\",\n",
+    "    # root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/\",\n",
     "    root=os.path.join(os.getcwd(), DATA_DIR_),\n",
     "    tables=[\"D_ICD_DIAGNOSES\", \"D_ICD_PROCEDURES\", \"D_ITEMS\", \"D_LABITEMS\",\n",
-    "            \"DIAGNOSES_ICD\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\", \"LABEVENTS\"],\n",
+    "            \"DIAGNOSES_ICD\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\"], # \"LABEVENTS\"],\n",
     "    # map all NDC codes to ATC 3-rd level codes in these tables\n",
     "    # See https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.\n",
     "    code_mapping={\"NDC\": (\"ATC\", {\"target_kwargs\": {\"level\": 3}})},\n",
     "    # Slow\n",
-    "    refresh_cache=True,\n",
+    "    refresh_cache=False,\n",
     ")\n",
     "\n",
     "mimic3base.stat()\n",
@@ -496,7 +501,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 100,
+   "execution_count": 125,
    "id": "4cef4476-6bc8-4791-b246-4c329a80a2e4",
    "metadata": {},
    "outputs": [
@@ -504,7 +509,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "dict_keys(['D_ICD_DIAGNOSES', 'D_ITEMS', 'D_ICD_PROCEDURES', 'D_LABITEMS'])\n",
+      "['D_ICD_DIAGNOSES', 'D_ITEMS', 'D_ICD_PROCEDURES', 'D_LABITEMS']\n",
       "Table: D_ICD_DIAGNOSES\n",
       "[['0010', 'Cholera d/t vib cholerae', 'Cholera due to vibrio cholerae'], ['0011', 'Cholera d/t vib el tor', 'Cholera due to vibrio cholerae el tor'], ['0019', 'Cholera NOS', 'Cholera, unspecified'], ['0020', 'Typhoid fever', 'Typhoid fever'], ['0021', 'Paratyphoid fever a', 'Paratyphoid fever A']]\n",
       "\n",
@@ -536,7 +541,14 @@
     "    d = mimic3base.get_text_dict(t)\n",
     "    print(f\"Table: {t}\")\n",
     "    print(d['data'][:5])\n",
-    "    print('\\n\\n')\n"
+    "    print('\\n\\n')\n",
+    "\n",
+    "\n",
+    "for t in table_names:\n",
+    "    d = mimic3base.get_text_dict(t)\n",
+    "    d = d['data']\n",
+    "    lut = {record[0]: record[1:] for record in d}\n",
+    "    mimic3base.set_text_lut(t, lut)"
    ]
   },
   {
@@ -551,7 +563,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 115,
    "id": "08692aad-db67-487b-9b26-dbde0ab6adce",
    "metadata": {},
    "outputs": [],
@@ -573,6 +585,7 @@
     "    for i in range(len(patient) - 1):\n",
     "\n",
     "        # visit and next_visit are both <pyhealth.data.Visit> objects\n",
+    "        # there are no vists.attr_dict keys\n",
     "        visit = patient[i]\n",
     "        next_visit = patient[i + 1]\n",
     "\n",
@@ -585,11 +598,20 @@
     "        # step 2: get code-based feature information\n",
     "        conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n",
     "        procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n",
-    "        drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n",
-    "        labevents = visit.get_code_list(table=\"LABEVENTS\")\n",
+    "        # drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n",
+    "        # TODO(botelho3) - add this datasource back in once we have full MIMIC-III dataset.\n",
+    "        # labevents = visit.get_code_list(table=\"LABEVENTS\")\n",
     "\n",
     "        # step 3: exclusion criteria: visits without condition, procedure, or drug\n",
-    "        if len(conditions) * len(procedures) * len(drugs) == 0: continue\n",
+    "        if len(conditions) * len(procedures) * len(labevents) == 0: continue\n",
+    "        \n",
+    "        # step 3.5: build text lists from the ICD codes\n",
+    "        d_diag = mimic3base._text_luts.get(\"D_DIAGNOSES_ICD\").get()\n",
+    "        d_proc = mimic3base._text_luts.get(\"D_PROCEDURES_ICD\").get()\n",
+    "        # Index 0 is shortname, index 1 is longname.\n",
+    "        conditions_text = [d_diag.get(cond)[1] for cond in conditions]\n",
+    "        procedutes_text = [d_proc.get(proc)[1] for proc in procedures]\n",
+    "        # labevents_text =\n",
     "        \n",
     "        # step 4: assemble the samples\n",
     "        samples.append(\n",
@@ -599,7 +621,11 @@
     "                # the following keys can be the \"feature_keys\" or \"label_key\" for initializing downstream ML model\n",
     "                \"conditions\": conditions,\n",
     "                \"procedures\": procedures,\n",
-    "                \"drugs\": drugs,\n",
+    "                \"conditions_text\": conditions_text,\n",
+    "                \"procedures_text\": procedures_text,\n",
+    "                # \"drugs\": drugs,\n",
+    "                # \"labevents\": labevents,\n",
+    "                # \"labevents_text\": labevents_text\n",
     "                \"label\": mortality_label,\n",
     "            }\n",
     "        )\n",
@@ -628,6 +654,7 @@
     "        conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n",
     "        procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n",
     "        drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n",
+    "        labevents = visit.get_code_list(table=\"LABEVENTS\")\n",
     "\n",
     "        # step 3: exclusion criteria: visits without condition, procedure, or drug\n",
     "        if len(conditions) * len(procedures) * len(drugs) == 0: continue\n",
@@ -653,13 +680,224 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 105,
    "id": "96d047ff-6412-462c-9294-a5ba2d3bdba2",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Generating samples for readmission_pred_task:  44%|███████████████▍                   | 44/100 [00:00<00:00, 386.68it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "visit num_events: 598\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 252\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 414\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 276\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 314\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 774\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 1723\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 635\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 856\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 288\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 252\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 393\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 641\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 310\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 361\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 178\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 385\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Generating samples for readmission_pred_task: 100%|██████████████████████████████████| 100/100 [00:00<00:00, 328.61it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "visit num_events: 908\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 498\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 1178\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 822\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 185\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 1496\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 177\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 1036\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 130\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 369\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 268\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "visit num_events: 136\n",
+      "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n",
+      "visit attrs\n",
+      "{}\n",
+      "Statistics of sample dataset:\n",
+      "\t- Dataset: MIMIC3DatasetWrapper\n",
+      "\t- Task: readmission_pred_task\n",
+      "\t- Number of samples: 27\n",
+      "\t- Number of patients: 13\n",
+      "\t- Number of visits: 27\n",
+      "\t- Number of visits per patient: 2.0769\n",
+      "\t- conditions:\n",
+      "\t\t- Number of conditions per sample: 1.0000\n",
+      "\t\t- Number of unique conditions: 17\n",
+      "\t\t- Distribution of conditions (Top-10): [('0389', 8), ('486', 3), ('03849', 2), ('5715', 1), ('41071', 1), ('5750', 1), ('51884', 1), ('20280', 1), ('53642', 1), ('51881', 1)]\n",
+      "\t- procedures:\n",
+      "\t\t- Number of procedures per sample: 3.4074\n",
+      "\t\t- Number of unique procedures: 42\n",
+      "\t\t- Distribution of procedures (Top-10): [('966', 15), ('3893', 14), ('9671', 5), ('9672', 4), ('9604', 4), ('4513', 3), ('3891', 3), ('9915', 3), ('9904', 3), ('3995', 2)]\n",
+      "\t- drugs:\n",
+      "\t\t- Number of drugs per sample: 40.7037\n",
+      "\t\t- Number of unique drugs: 112\n",
+      "\t\t- Distribution of drugs (Top-10): [('B05X', 27), ('B01A', 25), ('N02A', 25), ('A02B', 24), ('A06A', 24), ('N02B', 23), ('J01X', 21), ('A12C', 20), ('A12B', 20), ('V04C', 19)]\n",
+      "\t- label:\n",
+      "\t\t- Number of label per sample: 1.0000\n",
+      "\t\t- Number of unique label: 2\n",
+      "\t\t- Distribution of label (Top-10): [(0, 26), (1, 1)]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'visit_id': '122098',\n",
+       " 'patient_id': '10059',\n",
+       " 'conditions': '5715',\n",
+       " 'procedures': ['4513',\n",
+       "  '3891',\n",
+       "  '3893',\n",
+       "  '9672',\n",
+       "  '9915',\n",
+       "  '3895',\n",
+       "  '3995',\n",
+       "  '5491'],\n",
+       " 'drugs': ['A12C',\n",
+       "  'B05X',\n",
+       "  'V04C',\n",
+       "  'A02B',\n",
+       "  'J01F',\n",
+       "  'N05C',\n",
+       "  'N01A',\n",
+       "  'A12A',\n",
+       "  'H01C',\n",
+       "  'B02B',\n",
+       "  'V06D',\n",
+       "  'A10A',\n",
+       "  'J01M',\n",
+       "  'C01C',\n",
+       "  'B01A',\n",
+       "  'B05B',\n",
+       "  'C05B',\n",
+       "  'J01X',\n",
+       "  'A06A',\n",
+       "  'B05A',\n",
+       "  'H02A',\n",
+       "  'N02A',\n",
+       "  'A04A'],\n",
+       " 'label': 0}"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
-    "re_dataset = dataset.set_task(readmission_pred_task)\n",
-    "re_dataset.stat()"
+    "readm_dataset = mimic3base.set_task(readmission_pred_task)\n",
+    "readm_dataset.stat()\n",
+    "readm_dataset.samples[0]"
    ]
   },
   {
@@ -669,8 +907,72 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "mor_dataset = dataset.set_task(mortality_pred_task)\n",
-    "mor_dataset.stat()"
+    "mor_dataset = mimic3base.set_task(mortality_pred_task)\n",
+    "mor_dataset.stat()\n",
+    "mor_dataset.samples[0]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "66a0281a-f95b-4991-abd5-5824fc37c520",
+   "metadata": {},
+   "source": [
+    "### DataLoaders and Collate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "924c8e3c-0482-4688-b2d6-d83d490060a9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# def collate_fn(data):\n",
+    "#     \"\"\"\n",
+    "#     Collate the the list of samples into batches. For each patient, you need to pad the diagnosis\n",
+    "#         sequences to the sample shape (max # visits, len(freq_codes)). The padding infomation\n",
+    "#         is stored in `mask`.\n",
+    "    \n",
+    "#     Arguments:\n",
+    "#         data: a list of samples fetched from `CustomDataset`\n",
+    "        \n",
+    "#     Outputs:\n",
+    "#         x: a tensor of shape (# patiens, max # visits, len(freq_codes)) of type torch.float\n",
+    "#         masks: a tensor of shape (# patiens, max # visits) of type torch.bool\n",
+    "#         y: a tensor of shape (# patiens) of type torch.float\n",
+    "        \n",
+    "#     Note that you can obtains the list of diagnosis codes and the list of hf labels\n",
+    "#         using: `sequences, labels = zip(*data)`\n",
+    "#     \"\"\"\n",
+    "\n",
+    "#     sequences, labels = zip(*data)\n",
+    "\n",
+    "#     y = torch.tensor(labels, dtype=torch.float)\n",
+    "    \n",
+    "#     num_patients = len(sequences)\n",
+    "#     num_visits = [len(patient) for patient in sequences]\n",
+    "#     max_num_visits = max(num_visits)\n",
+    "    \n",
+    "#     x = torch.zeros((num_patients, max_num_visits, len(freq_codes)), dtype=torch.float)    \n",
+    "#     for i_patient, patient in enumerate(sequences):\n",
+    "#         kMaxVisits = len(patient)\n",
+    "#         for j_visit, visit in enumerate(patient):\n",
+    "#             l = len(visit)\n",
+    "#             for code in visit:\n",
+    "#                 \"\"\"\n",
+    "#                 TODO: 1. check if code is in freq_codes;\n",
+    "#                       2. obtain the code index using code2idx;\n",
+    "#                       3. set the correspoindg element in x to 1.\n",
+    "#                 \"\"\"\n",
+    "#                 try:\n",
+    "#                     idx = code2idx[code]\n",
+    "#                     x[i_patient, j_visit, idx] = 1\n",
+    "#                 except KeyError as e:\n",
+    "#                     pass\n",
+    "    \n",
+    "#     masks = torch.sum(x, dim=-1) > 0\n",
+    "    \n",
+    "#     return x, masks, y"
    ]
   },
   {
@@ -683,7 +985,61 @@
     "# create dataloaders (torch.data.DataLoader)\n",
     "train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)\n",
     "val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)\n",
-    "test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n"
+    "test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n",
+    "\n",
+    "# from torch.utils.data import DataLoader\n",
+    "\n",
+    "# def load_data(train_dataset, val_dataset, collate_fn):\n",
+    "    \n",
+    "#     '''\n",
+    "#     TODO: Implement this function to return the data loader for  train and validation dataset. \n",
+    "#     Set batchsize to 32. Set `shuffle=True` only for train dataloader.\n",
+    "    \n",
+    "#     Arguments:\n",
+    "#         train dataset: train dataset of type `CustomDataset`\n",
+    "#         val dataset: validation dataset of type `CustomDataset`\n",
+    "#         collate_fn: collate function\n",
+    "        \n",
+    "#     Outputs:\n",
+    "#         train_loader, val_loader: train and validation dataloaders\n",
+    "    \n",
+    "#     Note that you need to pass the collate function to the data loader `collate_fn()`.\n",
+    "#     '''\n",
+    "    \n",
+    "#     batch_size = 32\n",
+    "#     train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)\n",
+    "#     val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)\n",
+    "    \n",
+    "\n",
+    "\n",
+    "#     return train_loader, val_loader\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "035fccee-d5da-428c-a54a-0678a2d19b5e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Verify DataLoader properties.\n",
+    "\n",
+    "# from torch.utils.data import DataLoader\n",
+    "\n",
+    "# loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)\n",
+    "# loader_iter = iter(loader)\n",
+    "# x, masks, y = next(loader_iter)\n",
+    "\n",
+    "# assert x.dtype == torch.float\n",
+    "# assert y.dtype == torch.float\n",
+    "# assert masks.dtype == torch.bool\n",
+    "\n",
+    "# assert x.shape == (10, 3, 105)\n",
+    "# assert y.shape == (10,)\n",
+    "# assert masks.shape == (10, 3)\n",
+    "\n",
+    "# assert x[0][0].sum() == 9\n",
+    "# assert masks[0].sum() == 2"
    ]
   },
   {
-- 
GitLab