diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb
index e323ee8849a6bbf0d50686671ac2212428e59407..86824ad3ef2df6f4feb7136a5cc23abd757a12f0 100644
--- a/src/cs598_dlh_final_project.ipynb
+++ b/src/cs598_dlh_final_project.ipynb
@@ -10,7 +10,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 117,
    "id": "5da368cd-3045-4ec0-86fb-2ce15b5d1b92",
    "metadata": {},
    "outputs": [
@@ -30,6 +30,7 @@
     "import math\n",
     "import itertools\n",
     "import functools\n",
+    "from copy import deepcopy\n",
     "#from termcolor import colored, cprint\n",
     "import colored\n",
     "from datetime import datetime, timedelta\n",
@@ -75,7 +76,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 129,
    "id": "ebef3528-0396-459a-8112-b272089f5d67",
    "metadata": {},
    "outputs": [],
@@ -86,6 +87,7 @@
     "BATCH_SIZE_ = 32\n",
     "EMBEDDING_DIM_ = 264  # BERT requires a multiple of 12\n",
     "SHUFFLE_ = True\n",
+    "SAMPLE_MULTIPLIER_ = 3\n",
     "# Set seed for reproducibility.\n",
     "seed = 90210\n",
     "random.seed(seed)\n",
@@ -203,33 +205,33 @@
     "    def get_text_lut(self, table_name: str) -> Dict[Any, Any]:\n",
     "        return self._text_luts[table_name]\n",
     "    \n",
-    "    def _validate(self) -> Dict:\n",
-    "        \"\"\"Helper function which validates the samples.\n",
-    "        Will be called in `self.__init__()`.\n",
-    "        Returns:\n",
-    "            input_info: Dict, a dict whose keys are the same as the keys in the\n",
-    "                samples, and values are the corresponding input information:\n",
-    "                - \"type\": the element type of each key attribute, one of float,\n",
-    "                    int, str.\n",
-    "                - \"dim\": the list dimension of each key attribute, one of 0, 1, 2, 3.\n",
-    "                - \"len\": the length of the vector, only valid for vector-based\n",
-    "                    attributes.\n",
-    "        \"\"\"\n",
-    "        \"\"\" 1. Check if all samples are of type dict. \"\"\"\n",
-    "        assert all(\n",
-    "            [isinstance(s, dict) for s in self.samples],\n",
-    "        ), \"Each sample should be a dict\"\n",
-    "        keys = self.samples[0].keys()\n",
-    "\n",
-    "        \"\"\" 2. Check if all samples have the same keys. \"\"\"\n",
-    "        assert all(\n",
-    "            [set(s.keys()) == set(keys) for s in self.samples]\n",
-    "        ), \"All samples should have the same keys\"\n",
-    "\n",
-    "        \"\"\" 3. Check if \"patient_id\" and \"visit_id\" are in the keys.\"\"\"\n",
-    "        assert \"patient_id\" in keys, \"patient_id should be in the keys\"\n",
-    "        assert \"visit_id\" in keys, \"visit_id should be in the keys\"\n",
-    "        return {}\n",
+    "#     def _validate(self) -> Dict:\n",
+    "#         \"\"\"Helper function which validates the samples.\n",
+    "#         Will be called in `self.__init__()`.\n",
+    "#         Returns:\n",
+    "#             input_info: Dict, a dict whose keys are the same as the keys in the\n",
+    "#                 samples, and values are the corresponding input information:\n",
+    "#                 - \"type\": the element type of each key attribute, one of float,\n",
+    "#                     int, str.\n",
+    "#                 - \"dim\": the list dimension of each key attribute, one of 0, 1, 2, 3.\n",
+    "#                 - \"len\": the length of the vector, only valid for vector-based\n",
+    "#                     attributes.\n",
+    "#         \"\"\"\n",
+    "#         \"\"\" 1. Check if all samples are of type dict. \"\"\"\n",
+    "#         assert all(\n",
+    "#             [isinstance(s, dict) for s in self.samples],\n",
+    "#         ), \"Each sample should be a dict\"\n",
+    "#         keys = self.samples[0].keys()\n",
+    "\n",
+    "#         \"\"\" 2. Check if all samples have the same keys. \"\"\"\n",
+    "#         assert all(\n",
+    "#             [set(s.keys()) == set(keys) for s in self.samples]\n",
+    "#         ), \"All samples should have the same keys\"\n",
+    "\n",
+    "#         \"\"\" 3. Check if \"patient_id\" and \"visit_id\" are in the keys.\"\"\"\n",
+    "#         assert \"patient_id\" in keys, \"patient_id should be in the keys\"\n",
+    "#         assert \"visit_id\" in keys, \"visit_id should be in the keys\"\n",
+    "#         return {}\n",
     "\n",
     "    \n",
     "    def _add_events_to_patient_dict(\n",
@@ -866,7 +868,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 96,
+   "execution_count": 130,
    "id": "634ef5a4-fa6d-4720-a4e7-e91cc3746ef5",
    "metadata": {},
    "outputs": [],
@@ -885,8 +887,9 @@
     "    samples = []\n",
     "    visits = []\n",
     "    kMaxListSize = 40\n",
-    "    \n",
-    "    if len(patient) <= 1:\n",
+    "   \n",
+    "    # Length 1 patients by defn are not readmitted.\n",
+    "    if len(patient) < 1:\n",
     "        return samples\n",
     "\n",
     "    # we will drop the last visit\n",
@@ -1002,8 +1005,8 @@
     "    #     \"label\": global_readmission_label,\n",
     "    # }\n",
     "\n",
-    "    samples.append(\n",
-    "        sample\n",
+    "    samples.extend(\n",
+    "        list(deepcopy(sample) for s in range(SAMPLE_MULTIPLIER_))\n",
     "    )\n",
     "    return samples\n",
     "\n",
@@ -1148,15 +1151,15 @@
     "    #     \"label\": global_mortality_label,\n",
     "    # }\n",
     "\n",
-    "    samples.append(\n",
-    "        sample\n",
+    "    samples.extend(\n",
+    "        list(deepcopy(sample) for s in range(SAMPLE_MULTIPLIER_))\n",
     "    )\n",
     "    return samples"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 94,
+   "execution_count": 131,
    "id": "96d047ff-6412-462c-9294-a5ba2d3bdba2",
    "metadata": {},
    "outputs": [
@@ -1164,7 +1167,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Generating samples for readmission_pred_task_per_patient: 100%|██████| 100/100 [00:00<00:00, 9155.87it/s]"
+      "Generating samples for readmission_pred_task_per_patient: 100%|██████| 100/100 [00:00<00:00, 1358.65it/s]"
      ]
     },
     {
@@ -1174,66 +1177,66 @@
       "Statistics of sample dataset:\n",
       "\t- Dataset: MIMIC3DatasetWrapper\n",
       "\t- Task: readmission_pred_task_per_patient\n",
-      "\t- Number of samples: 14\n",
-      "\t- Number of patients: 14\n",
-      "\t- Number of visits: 14\n",
-      "\t- Number of visits per patient: 1.0000\n",
+      "\t- Number of samples: 261\n",
+      "\t- Number of patients: 87\n",
+      "\t- Number of visits: 87\n",
+      "\t- Number of visits per patient: 3.0000\n",
       "\t- num_visits:\n",
       "\t\t- Number of num_visits per sample: 1.0000\n",
       "\t\t- Number of unique num_visits: 4\n",
-      "\t\t- Distribution of num_visits (Top-10): [(2, 8), (1, 3), (3, 2), (14, 1)]\n",
+      "\t\t- Distribution of num_visits (Top-10): [(1, 228), (2, 24), (3, 6), (14, 3)]\n",
       "\t- conditions:\n",
-      "\t\t- Number of conditions per sample: 50.0000\n",
-      "\t\t- Number of unique conditions: 210\n",
-      "\t\t- Distribution of conditions (Top-10): [('0', 323), ('5849', 13), ('51881', 9), ('4019', 9), ('486', 9), ('99592', 7), ('4280', 7), ('42731', 7), ('0389', 6), ('78552', 6)]\n",
+      "\t\t- Number of conditions per sample: 40.0000\n",
+      "\t\t- Number of unique conditions: 482\n",
+      "\t\t- Distribution of conditions (Top-10): [('0', 6726), ('4019', 102), ('42731', 99), ('4280', 99), ('5849', 99), ('51881', 78), ('25000', 69), ('486', 69), ('99592', 51), ('0389', 48)]\n",
       "\t- conditions_text:\n",
-      "\t\t- Number of conditions_text per sample: 50.0000\n",
-      "\t\t- Number of unique conditions_text: 202\n",
-      "\t\t- Distribution of conditions_text (Top-10): [('', 333), ('Acute kidney failure, unspecified', 13), ('Acute respiratory failure', 9), ('Unspecified essential hypertension', 9), ('Pneumonia, organism unspecified', 9), ('Severe sepsis', 7), ('Congestive heart failure, unspecified', 7), ('Atrial fibrillation', 7), ('Unspecified septicemia', 6), ('Septic shock', 6)]\n",
+      "\t\t- Number of conditions_text per sample: 40.0000\n",
+      "\t\t- Number of unique conditions_text: 465\n",
+      "\t\t- Distribution of conditions_text (Top-10): [('', 6837), ('Unspecified essential hypertension', 102), ('Atrial fibrillation', 99), ('Congestive heart failure, unspecified', 99), ('Acute kidney failure, unspecified', 99), ('Acute respiratory failure', 78), ('Diabetes mellitus without mention of complication, type II or unspecified type, not stated as uncontrolled', 69), ('Pneumonia, organism unspecified', 69), ('Severe sepsis', 51), ('Unspecified septicemia', 48)]\n",
       "\t- procedures:\n",
-      "\t\t- Number of procedures per sample: 50.0000\n",
-      "\t\t- Number of unique procedures: 55\n",
-      "\t\t- Distribution of procedures (Top-10): [('0', 574), ('3893', 18), ('966', 17), ('9604', 8), ('9672', 7), ('9671', 7), ('4513', 5), ('3891', 4), ('9904', 4), ('9915', 3)]\n",
+      "\t\t- Number of procedures per sample: 40.0000\n",
+      "\t\t- Number of unique procedures: 151\n",
+      "\t\t- Distribution of procedures (Top-10): [('0', 9264), ('3893', 114), ('966', 78), ('9604', 57), ('9671', 54), ('9672', 54), ('9904', 42), ('5491', 27), ('9915', 27), ('3995', 21)]\n",
       "\t- procedures_text:\n",
-      "\t\t- Number of procedures_text per sample: 50.0000\n",
-      "\t\t- Number of unique procedures_text: 55\n",
-      "\t\t- Distribution of procedures_text (Top-10): [('', 574), ('Venous catheterization, not elsewhere classified', 18), ('Enteral infusion of concentrated nutritional substances', 17), ('Insertion of endotracheal tube', 8), ('Continuous invasive mechanical ventilation for 96 consecutive hours or more', 7), ('Continuous invasive mechanical ventilation for less than 96 consecutive hours', 7), ('Other endoscopy of small intestine', 5), ('Arterial catheterization', 4), ('Transfusion of packed cells', 4), ('Parenteral infusion of concentrated nutritional substances', 3)]\n",
+      "\t\t- Number of procedures_text per sample: 40.0000\n",
+      "\t\t- Number of unique procedures_text: 148\n",
+      "\t\t- Distribution of procedures_text (Top-10): [('', 9273), ('Venous catheterization, not elsewhere classified', 114), ('Enteral infusion of concentrated nutritional substances', 78), ('Insertion of endotracheal tube', 57), ('Continuous invasive mechanical ventilation for less than 96 consecutive hours', 54), ('Continuous invasive mechanical ventilation for 96 consecutive hours or more', 54), ('Transfusion of packed cells', 42), ('Percutaneous abdominal drainage', 27), ('Parenteral infusion of concentrated nutritional substances', 27), ('Hemodialysis', 21)]\n",
       "\t- drugs:\n",
-      "\t\t- Number of drugs per sample: 50.0000\n",
-      "\t\t- Number of unique drugs: 103\n",
-      "\t\t- Distribution of drugs (Top-10): [('0', 141), ('B05X', 22), ('A02B', 20), ('B01A', 19), ('A06A', 17), ('N02A', 17), ('N02B', 17), ('V04C', 15), ('A12C', 14), ('B05B', 13)]\n",
+      "\t\t- Number of drugs per sample: 37.7586\n",
+      "\t\t- Number of unique drugs: 137\n",
+      "\t\t- Distribution of drugs (Top-10): [('0', 3492), ('B05X', 252), ('A02B', 225), ('B01A', 219), ('N02B', 213), ('A06A', 204), ('N02A', 195), ('V06D', 189), ('A12C', 171), ('V04C', 168)]\n",
       "\t- drugs_text:\n",
-      "\t\t- Number of drugs_text per sample: 50.0000\n",
-      "\t\t- Number of unique drugs_text: 251\n",
-      "\t\t- Distribution of drugs_text (Top-10): [('', 55), ('Neutra-Phos MAIN Powder Packet PO 0.0', 35), ('Bisacodyl MAIN 5 mg Tab PO 0.0', 28), ('Bisacodyl MAIN 10mg Suppository PR 0.0', 21), ('Furosemide MAIN 40mg/4mL Vial IV 0.0', 16), ('Magnesium Sulfate MAIN 2 g / 50 mL Premix Bag IV 0.0', 15), ('Nephrocaps MAIN 1 CAP PPK PO 0.0', 10), ('Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0', 9), ('NS BASE 500mL Bag IV 0.0', 8), ('Albuterol 0.083% Neb Soln MAIN 0.083%;3mL Vial IH 0.0', 8)]\n",
+      "\t\t- Number of drugs_text per sample: 37.7586\n",
+      "\t\t- Number of unique drugs_text: 700\n",
+      "\t\t- Distribution of drugs_text (Top-10): [('', 864), ('Magnesium Sulfate MAIN 2 g / 50 mL Premix Bag IV 0.0', 294), ('Neutra-Phos MAIN Powder Packet PO 0.0', 273), ('Bisacodyl MAIN 5 mg Tab PO 0.0', 180), ('Furosemide MAIN 40mg/4mL Vial IV 0.0', 168), ('Bisacodyl MAIN 10mg Suppository PR 0.0', 162), ('Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0', 117), ('Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0', 111), ('0.9% Sodium Chloride BASE 1000mL Bag IV 0.0', 111), ('D5W BASE 250mL Bag IV DRIP 0.0', 105)]\n",
       "\t- conditions_pad:\n",
       "\t\t- Number of conditions_pad per sample: 1.0000\n",
-      "\t\t- Number of unique conditions_pad: 11\n",
-      "\t\t- Distribution of conditions_pad (Top-10): [(49, 3), (33, 2), (22, 1), (19, 1), (8, 1), (15, 1), (34, 1), (13, 1), (21, 1), (18, 1)]\n",
+      "\t\t- Number of unique conditions_pad: 26\n",
+      "\t\t- Distribution of conditions_pad (Top-10): [(8, 51), (13, 21), (10, 18), (14, 18), (7, 15), (22, 12), (9, 12), (17, 12), (39, 9), (4, 9)]\n",
       "\t- procedures_pad:\n",
       "\t\t- Number of procedures_pad per sample: 1.0000\n",
-      "\t\t- Number of unique procedures_pad: 10\n",
-      "\t\t- Distribution of procedures_pad (Top-10): [(9, 2), (0, 2), (4, 2), (5, 2), (2, 1), (11, 1), (7, 1), (15, 1), (45, 1), (10, 1)]\n",
+      "\t\t- Number of unique procedures_pad: 18\n",
+      "\t\t- Distribution of procedures_pad (Top-10): [(0, 48), (2, 42), (1, 36), (3, 24), (5, 21), (7, 18), (6, 12), (4, 12), (9, 9), (10, 9)]\n",
       "\t- conditions_text_pad:\n",
       "\t\t- Number of conditions_text_pad per sample: 1.0000\n",
-      "\t\t- Number of unique conditions_text_pad: 11\n",
-      "\t\t- Distribution of conditions_text_pad (Top-10): [(49, 3), (33, 2), (22, 1), (19, 1), (8, 1), (15, 1), (34, 1), (13, 1), (21, 1), (18, 1)]\n",
+      "\t\t- Number of unique conditions_text_pad: 26\n",
+      "\t\t- Distribution of conditions_text_pad (Top-10): [(8, 51), (13, 21), (10, 18), (14, 18), (7, 15), (22, 12), (9, 12), (17, 12), (39, 9), (4, 9)]\n",
       "\t- procedures_text_pad:\n",
       "\t\t- Number of procedures_text_pad per sample: 1.0000\n",
-      "\t\t- Number of unique procedures_text_pad: 10\n",
-      "\t\t- Distribution of procedures_text_pad (Top-10): [(9, 2), (0, 2), (4, 2), (5, 2), (2, 1), (11, 1), (7, 1), (15, 1), (45, 1), (10, 1)]\n",
+      "\t\t- Number of unique procedures_text_pad: 18\n",
+      "\t\t- Distribution of procedures_text_pad (Top-10): [(0, 48), (2, 42), (1, 36), (3, 24), (5, 21), (7, 18), (6, 12), (4, 12), (9, 9), (10, 9)]\n",
       "\t- drugs_pad:\n",
       "\t\t- Number of drugs_pad per sample: 1.0000\n",
-      "\t\t- Number of unique drugs_pad: 7\n",
-      "\t\t- Distribution of drugs_pad (Top-10): [(49, 7), (36, 2), (26, 1), (30, 1), (5, 1), (39, 1), (44, 1)]\n",
+      "\t\t- Number of unique drugs_pad: 31\n",
+      "\t\t- Distribution of drugs_pad (Top-10): [(39, 57), (17, 18), (-1, 15), (22, 15), (14, 15), (19, 12), (25, 12), (24, 9), (12, 9), (36, 9)]\n",
       "\t- drugs_text_pad:\n",
       "\t\t- Number of drugs_text_pad per sample: 1.0000\n",
-      "\t\t- Number of unique drugs_text_pad: 2\n",
-      "\t\t- Distribution of drugs_text_pad (Top-10): [(49, 13), (8, 1)]\n",
+      "\t\t- Number of unique drugs_text_pad: 16\n",
+      "\t\t- Distribution of drugs_text_pad (Top-10): [(39, 189), (-1, 15), (24, 9), (31, 6), (37, 6), (34, 6), (27, 3), (32, 3), (36, 3), (38, 3)]\n",
       "\t- label:\n",
       "\t\t- Number of label per sample: 1.0000\n",
       "\t\t- Number of unique label: 2\n",
-      "\t\t- Distribution of label (Top-10): [(1, 10), (0, 4)]\n"
+      "\t\t- Distribution of label (Top-10): [(0, 231), (1, 30)]\n"
      ]
     },
     {
@@ -1246,39 +1249,29 @@
     {
      "data": {
       "text/plain": [
-       "{'patient_id': '10059',\n",
-       " 'visit_id': '122098',\n",
-       " 'num_visits': 2,\n",
-       " 'conditions': ['5715',\n",
-       "  '452',\n",
-       "  '07070',\n",
-       "  '2800',\n",
-       "  '45620',\n",
-       "  '0389',\n",
-       "  '99592',\n",
-       "  '78552',\n",
-       "  '51881',\n",
-       "  '5724',\n",
-       "  '5849',\n",
-       "  '5859',\n",
-       "  '78951',\n",
-       "  '5723',\n",
-       "  '5715',\n",
-       "  '5849',\n",
-       "  '07070',\n",
-       "  '2800',\n",
-       "  '2875',\n",
-       "  '78951',\n",
-       "  '45620',\n",
-       "  '5723',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
+       "{'patient_id': '10006',\n",
+       " 'visit_id': '142345',\n",
+       " 'num_visits': 1,\n",
+       " 'conditions': ['99591',\n",
+       "  '99662',\n",
+       "  '5672',\n",
+       "  '40391',\n",
+       "  '42731',\n",
+       "  '4280',\n",
+       "  '4241',\n",
+       "  '4240',\n",
+       "  '2874',\n",
+       "  '03819',\n",
+       "  '7850',\n",
+       "  'E8791',\n",
+       "  'V090',\n",
+       "  '56211',\n",
+       "  '28529',\n",
+       "  '25000',\n",
+       "  'V5867',\n",
+       "  'E9342',\n",
+       "  '41401',\n",
+       "  '2749',\n",
        "  '0',\n",
        "  '0',\n",
        "  '0',\n",
@@ -1299,36 +1292,26 @@
        "  '0',\n",
        "  '0',\n",
        "  '0'],\n",
-       " 'conditions_text': ['Cirrhosis of liver without mention of alcohol',\n",
-       "  'Portal vein thrombosis',\n",
-       "  'Unspecified viral hepatitis C without hepatic coma',\n",
-       "  'Iron deficiency anemia secondary to blood loss (chronic)',\n",
-       "  'Esophageal varices in diseases classified elsewhere, with bleeding',\n",
-       "  'Unspecified septicemia',\n",
-       "  'Severe sepsis',\n",
-       "  'Septic shock',\n",
-       "  'Acute respiratory failure',\n",
-       "  'Hepatorenal syndrome',\n",
-       "  'Acute kidney failure, unspecified',\n",
-       "  'Chronic kidney disease, unspecified',\n",
-       "  'Malignant ascites',\n",
-       "  'Portal hypertension',\n",
-       "  'Cirrhosis of liver without mention of alcohol',\n",
-       "  'Acute kidney failure, unspecified',\n",
-       "  'Unspecified viral hepatitis C without hepatic coma',\n",
-       "  'Iron deficiency anemia secondary to blood loss (chronic)',\n",
-       "  'Thrombocytopenia, unspecified',\n",
-       "  'Malignant ascites',\n",
-       "  'Esophageal varices in diseases classified elsewhere, with bleeding',\n",
-       "  'Portal hypertension',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
+       " 'conditions_text': ['Sepsis',\n",
+       "  'Infection and inflammatory reaction due to other vascular device, implant, and graft',\n",
        "  '',\n",
+       "  'Hypertensive chronic kidney disease, unspecified, with chronic kidney disease stage V or end stage renal disease',\n",
+       "  'Atrial fibrillation',\n",
+       "  'Congestive heart failure, unspecified',\n",
+       "  'Aortic valve disorders',\n",
+       "  'Mitral valve disorders',\n",
        "  '',\n",
+       "  'Other staphylococcal septicemia',\n",
+       "  'Tachycardia, unspecified',\n",
+       "  'Kidney dialysis as the cause of abnormal reaction of patient, or of later complication, without mention of misadventure at time of procedure',\n",
+       "  'Infection with microorganisms resistant to penicillins',\n",
+       "  'Diverticulitis of colon (without mention of hemorrhage)',\n",
+       "  'Anemia of other chronic disease',\n",
+       "  'Diabetes mellitus without mention of complication, type II or unspecified type, not stated as uncontrolled',\n",
+       "  'Long-term (current) use of insulin',\n",
+       "  'Anticoagulants causing adverse effects in therapeutic use',\n",
+       "  'Coronary atherosclerosis of native coronary artery',\n",
+       "  'Gout, unspecified',\n",
        "  '',\n",
        "  '',\n",
        "  '',\n",
@@ -1349,22 +1332,12 @@
        "  '',\n",
        "  '',\n",
        "  ''],\n",
-       " 'procedures': ['4513',\n",
-       "  '3891',\n",
-       "  '3893',\n",
-       "  '9672',\n",
-       "  '9915',\n",
+       " 'procedures': ['9749',\n",
+       "  '5491',\n",
        "  '3895',\n",
        "  '3995',\n",
-       "  '5491',\n",
-       "  '4233',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
-       "  '0',\n",
+       "  '3893',\n",
+       "  '9907',\n",
        "  '0',\n",
        "  '0',\n",
        "  '0',\n",
@@ -1399,22 +1372,12 @@
        "  '0',\n",
        "  '0',\n",
        "  '0'],\n",
-       " 'procedures_text': ['Other endoscopy of small intestine',\n",
-       "  'Arterial catheterization',\n",
-       "  'Venous catheterization, not elsewhere classified',\n",
-       "  'Continuous invasive mechanical ventilation for 96 consecutive hours or more',\n",
-       "  'Parenteral infusion of concentrated nutritional substances',\n",
+       " 'procedures_text': ['Removal of other device from thorax',\n",
+       "  'Percutaneous abdominal drainage',\n",
        "  'Venous catheterization for renal dialysis',\n",
        "  'Hemodialysis',\n",
-       "  'Percutaneous abdominal drainage',\n",
-       "  'Endoscopic excision or destruction of lesion or tissue of esophagus',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
-       "  '',\n",
+       "  'Venous catheterization, not elsewhere classified',\n",
+       "  'Transfusion of other serum',\n",
        "  '',\n",
        "  '',\n",
        "  '',\n",
@@ -1449,42 +1412,32 @@
        "  '',\n",
        "  '',\n",
        "  ''],\n",
-       " 'drugs': ['A12C',\n",
+       " 'drugs': ['A10B',\n",
+       "  'J01C',\n",
        "  'B05X',\n",
-       "  'V04C',\n",
+       "  'J01D',\n",
+       "  'J01X',\n",
+       "  'V06D',\n",
+       "  'V03A',\n",
+       "  'C07A',\n",
        "  'A02B',\n",
-       "  'J01F',\n",
+       "  'N02B',\n",
+       "  'A02A',\n",
+       "  'A06A',\n",
+       "  'A12C',\n",
+       "  'V04C',\n",
+       "  'A12B',\n",
+       "  'R06A',\n",
        "  'N05C',\n",
-       "  'N01A',\n",
-       "  'A12A',\n",
-       "  'H01C',\n",
        "  'B02B',\n",
-       "  'V06D',\n",
-       "  'A10A',\n",
+       "  'P01A',\n",
        "  'J01M',\n",
-       "  'C01C',\n",
-       "  'B01A',\n",
-       "  'B05B',\n",
-       "  'C05B',\n",
-       "  'J01X',\n",
-       "  'A06A',\n",
-       "  'B05A',\n",
-       "  'H02A',\n",
        "  'N02A',\n",
-       "  'A04A',\n",
-       "  'V06D',\n",
-       "  'H01C',\n",
-       "  'A02B',\n",
-       "  'J01M',\n",
-       "  'A04A',\n",
-       "  'N05A',\n",
-       "  'V04C',\n",
-       "  'B05B',\n",
-       "  'C05B',\n",
-       "  'C05A',\n",
-       "  'D04A',\n",
-       "  'B05X',\n",
-       "  'C07A',\n",
+       "  'B01A',\n",
+       "  'C09A',\n",
+       "  'C10A',\n",
+       "  '0',\n",
+       "  '0',\n",
        "  '0',\n",
        "  '0',\n",
        "  '0',\n",
@@ -1499,66 +1452,56 @@
        "  '0',\n",
        "  '0',\n",
        "  '0'],\n",
-       " 'drugs_text': ['Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n",
-       "  'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n",
-       "  'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n",
-       "  'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n",
-       "  'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n",
-       "  'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n",
-       "  'NS BASE 500mL Bag IV 0.0',\n",
-       "  'Pantoprazole Sodium MAIN 40mg Vial IV 0.0',\n",
-       "  'Erythromycin MAIN 1000 mg Vial IV 0.0',\n",
+       " 'drugs_text': ['Glipizide MAIN 10MG TAB PO 0.0',\n",
+       "  'Ampicillin Sodium MAIN 2g Vial IV 0.0',\n",
+       "  'NS BASE 50mL Bag IV 0.0',\n",
+       "  'Ceftriaxone MAIN 1g Frozen Bag IV 0.0',\n",
+       "  'Metronidazole MAIN 500mg Premix Bag IV 0.0',\n",
+       "  'Vancomycin HCl MAIN 10g Vial IV 0.0',\n",
+       "  'D5W BASE 250mL Bag IV 0.0',\n",
+       "  'Unasyn MAIN 3g Vial IV 0.0',\n",
+       "  'NS BASE 100mL Bag IV 0.0',\n",
+       "  'Sevelamer MAIN 800mg Tab PO 0.0',\n",
+       "  'Metoprolol MAIN 50mg Tab PO 0.0',\n",
+       "  'Pantoprazole MAIN 40mg Tab PO 0.0',\n",
+       "  'Acetaminophen MAIN 325mg Tab PO 0.0',\n",
+       "  'Magnesium Oxide MAIN 140mg Cap PPK PO 0.0',\n",
+       "  'Magnesium Oxide MAIN 140mg Cap PPK PO 0.0',\n",
+       "  'Magnesium Oxide MAIN 140mg Cap PPK PO 0.0',\n",
+       "  'Magnesium Sulfate MAIN 1gm/2ml vial IV 0.0',\n",
+       "  'Magnesium Sulfate MAIN 1gm/2ml vial IV 0.0',\n",
+       "  'Magnesium Sulfate MAIN 1gm/2ml vial IV 0.0',\n",
+       "  'Docusate Sodium MAIN 100mg Cap PO 0.0',\n",
+       "  'Senna MAIN 1 TAB PO 0.0',\n",
+       "  'Linezolid MAIN 600mg Tab PO 0.0',\n",
+       "  'Linezolid MAIN 600mg Tab PO 0.0',\n",
+       "  'Potassium Chloride MAIN 20mEq Packet PO 0.0',\n",
+       "  'Diphenhydramine HCl MAIN 25mg Cap PO 0.0',\n",
+       "  'Zolpidem Tartrate MAIN 5mg Tab PO 0.0',\n",
+       "  'Vancomycin HCl MAIN 1GM FROZ. BAG IV 0.0',\n",
        "  'Midazolam HCl MAIN 2mg/2mL Vial IV 0.0',\n",
-       "  'Fentanyl Citrate MAIN 100mcg/2mL Amp IV 0.0',\n",
-       "  'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n",
-       "  'NS BASE 150mL Bab IV 0.0',\n",
-       "  'Octreotide Acetate MAIN 50mcg/mL-1mL IV 0.0',\n",
-       "  'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n",
-       "  'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n",
-       "  'Phytonadione MAIN 5mg Tab PO 0.0',\n",
-       "  'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n",
-       "  'Dextrose 50% MAIN 50mL Syringe IV 0.0',\n",
-       "  'Insulin MAIN 10mL Vial IV 0.0',\n",
-       "  'Insulin MAIN 10mL Vial IV 0.0',\n",
-       "  'Calcium Chloride MAIN 1g/10mL Syringe IV 0.0',\n",
-       "  'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n",
-       "  'NS BASE 500mL Bag IV 0.0',\n",
-       "  'NS BASE 500mL Bag IV 0.0',\n",
-       "  'Ciprofloxacin MAIN 400mg PM Bag IV 0.0',\n",
-       "  'Norepinephrine MAIN 4mg/4mL Amp IV DRIP 0.0',\n",
-       "  'Heparin Flush CVL  (100 units/ml) MAIN 100Units/mL-5mL Syringe IV 0.0',\n",
-       "  'D5W BASE 100mL Bag IV DRIP 0.0',\n",
-       "  'D5W BASE 250mL Bag IV DRIP 0.0',\n",
-       "  'Octreotide Acetate MAIN 0.1 mg Amp IV DRIP 0.0',\n",
-       "  'Fentanyl Citrate MAIN 2.5mg/50mL Vial IV DRIP 0.0',\n",
-       "  'D5W BASE 250mL Bag IV DRIP 0.0',\n",
-       "  'Midazolam HCl MAIN 50mg/10mL Vial IV DRIP 0.0',\n",
-       "  'D5W BASE 100mL Bag IV DRIP 0.0',\n",
-       "  'Heparin Flush CRRT (5000 Units/mL) MAIN 5000 Units / mL- 1mL Vial IV 0.0',\n",
-       "  'NS BASE 500mL Bag IV 0.0',\n",
-       "  'Amino Acids 4.25% W/ Dextrose 5% BASE 1000ml Bag IV 0.0',\n",
-       "  'Amino Acids 4.25% W/ Dextrose 5% BASE 1000ml Bag IV 0.0',\n",
-       "  'Vancomycin HCl MAIN 1g Frozen Bag IV 0.0',\n",
-       "  'Sodium Bicarbonate MAIN 50mEq Vial IV 0.0',\n",
-       "  'Lactulose MAIN 1000ml Bottle PR 0.0',\n",
-       "  'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n",
-       "  'D5W BASE 3000 mL Bag IV 0.0',\n",
-       "  'Sodium Bicarbonate MAIN 50mEq Vial IV 0.0',\n",
-       "  'D5W BASE 1000mL Bag IV 0.0',\n",
-       "  'Insulin Human Regular MAIN 100Units/mL; 10mL Vial IV DRIP 0.0',\n",
-       "  'Insulin Human Regular MAIN 100Units/mL; 10mL Vial IV DRIP 0.0',\n",
-       "  'NS BASE 100mL Bag IV DRIP 0.0',\n",
+       "  'Midazolam HCl MAIN 2mg/2mL Vial IV 0.0',\n",
+       "  'Phytonadione MAIN 10mg/mL Amp SC 0.0',\n",
+       "  'Vancomycin HCl MAIN 1GM FROZ. BAG IV 0.0',\n",
+       "  'Metronidazole MAIN 500MG TAB PO 0.0',\n",
+       "  'Levofloxacin MAIN 250mg Tab PO 0.0',\n",
+       "  'Oxycodone-Acetaminophen MAIN 5mg/325mg Tab PO 0.0',\n",
+       "  'Alteplase (Catheter Clearance) MAIN 1mg/1mL Syringe IV 0.0',\n",
+       "  'Warfarin MAIN 2mg Tab PO 0.0',\n",
+       "  'Moexipril HCl MAIN 7.5mg Tab PO 0.0',\n",
+       "  'Sevelamer MAIN 400mg Tab PO 0.0',\n",
+       "  'Sevelamer MAIN 800mg Tab PO 0.0',\n",
        "  ''],\n",
-       " 'conditions_pad': 22,\n",
-       " 'procedures_pad': 9,\n",
-       " 'conditions_text_pad': 22,\n",
-       " 'procedures_text_pad': 9,\n",
-       " 'drugs_pad': 36,\n",
-       " 'drugs_text_pad': 49,\n",
-       " 'label': 1}"
+       " 'conditions_pad': 20,\n",
+       " 'procedures_pad': 6,\n",
+       " 'conditions_text_pad': 20,\n",
+       " 'procedures_text_pad': 6,\n",
+       " 'drugs_pad': 24,\n",
+       " 'drugs_text_pad': 39,\n",
+       " 'label': 0}"
       ]
      },
-     "execution_count": 94,
+     "execution_count": 131,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1572,7 +1515,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 97,
+   "execution_count": 132,
    "id": "70586b66-a814-4898-892b-d1bdd339ce7b",
    "metadata": {},
    "outputs": [
@@ -1580,21 +1523,343 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Generating samples for mortality_pred_task_per_patient:   0%|                    | 0/100 [00:00<?, ?it/s]\n"
+      "Generating samples for mortality_pred_task_per_patient: 100%|████████| 100/100 [00:00<00:00, 1197.22it/s]"
      ]
     },
     {
-     "ename": "NameError",
-     "evalue": "name 'global_readmission_label' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_66272/2572961812.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmor_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmimic3base\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_task\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmortality_pred_task_per_patient\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mmor_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mmor_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/envs/cs589_dl_healthcare_assignments/lib/python3.7/site-packages/pyhealth/datasets/base_dataset.py\u001b[0m in \u001b[0;36mset_task\u001b[0;34m(self, task_fn, task_name)\u001b[0m\n\u001b[1;32m    357\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpatients\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34mf\"Generating samples for {task_name}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    358\u001b[0m         ):\n\u001b[0;32m--> 359\u001b[0;31m             \u001b[0msamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtask_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpatient\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    360\u001b[0m         sample_dataset = SampleDataset(\n\u001b[1;32m    361\u001b[0m             \u001b[0msamples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_66272/4294464053.py\u001b[0m in \u001b[0;36mmortality_pred_task_per_patient\u001b[0;34m(patient)\u001b[0m\n\u001b[1;32m    257\u001b[0m         \u001b[0;31m# \"labevents\": labevents,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    258\u001b[0m         \u001b[0;31m# \"labevents_text\": labevents_text\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 259\u001b[0;31m         \u001b[0;34m\"label\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mglobal_readmission_label\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    260\u001b[0m     }\n\u001b[1;32m    261\u001b[0m     \u001b[0;31m# sample = {\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;31mNameError\u001b[0m: name 'global_readmission_label' is not defined"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Statistics of sample dataset:\n",
+      "\t- Dataset: MIMIC3DatasetWrapper\n",
+      "\t- Task: mortality_pred_task_per_patient\n",
+      "\t- Number of samples: 261\n",
+      "\t- Number of patients: 87\n",
+      "\t- Number of visits: 87\n",
+      "\t- Number of visits per patient: 3.0000\n",
+      "\t- num_visits:\n",
+      "\t\t- Number of num_visits per sample: 1.0000\n",
+      "\t\t- Number of unique num_visits: 4\n",
+      "\t\t- Distribution of num_visits (Top-10): [(1, 228), (2, 24), (3, 6), (14, 3)]\n",
+      "\t- conditions:\n",
+      "\t\t- Number of conditions per sample: 40.0000\n",
+      "\t\t- Number of unique conditions: 482\n",
+      "\t\t- Distribution of conditions (Top-10): [('0', 6726), ('4019', 102), ('42731', 99), ('4280', 99), ('5849', 99), ('51881', 78), ('25000', 69), ('486', 69), ('99592', 51), ('0389', 48)]\n",
+      "\t- conditions_text:\n",
+      "\t\t- Number of conditions_text per sample: 40.0000\n",
+      "\t\t- Number of unique conditions_text: 465\n",
+      "\t\t- Distribution of conditions_text (Top-10): [('', 6837), ('Unspecified essential hypertension', 102), ('Atrial fibrillation', 99), ('Congestive heart failure, unspecified', 99), ('Acute kidney failure, unspecified', 99), ('Acute respiratory failure', 78), ('Diabetes mellitus without mention of complication, type II or unspecified type, not stated as uncontrolled', 69), ('Pneumonia, organism unspecified', 69), ('Severe sepsis', 51), ('Unspecified septicemia', 48)]\n",
+      "\t- procedures:\n",
+      "\t\t- Number of procedures per sample: 40.0000\n",
+      "\t\t- Number of unique procedures: 151\n",
+      "\t\t- Distribution of procedures (Top-10): [('0', 9264), ('3893', 114), ('966', 78), ('9604', 57), ('9671', 54), ('9672', 54), ('9904', 42), ('5491', 27), ('9915', 27), ('3995', 21)]\n",
+      "\t- procedures_text:\n",
+      "\t\t- Number of procedures_text per sample: 40.0000\n",
+      "\t\t- Number of unique procedures_text: 148\n",
+      "\t\t- Distribution of procedures_text (Top-10): [('', 9273), ('Venous catheterization, not elsewhere classified', 114), ('Enteral infusion of concentrated nutritional substances', 78), ('Insertion of endotracheal tube', 57), ('Continuous invasive mechanical ventilation for less than 96 consecutive hours', 54), ('Continuous invasive mechanical ventilation for 96 consecutive hours or more', 54), ('Transfusion of packed cells', 42), ('Percutaneous abdominal drainage', 27), ('Parenteral infusion of concentrated nutritional substances', 27), ('Hemodialysis', 21)]\n",
+      "\t- drugs:\n",
+      "\t\t- Number of drugs per sample: 37.7586\n",
+      "\t\t- Number of unique drugs: 137\n",
+      "\t\t- Distribution of drugs (Top-10): [('0', 3492), ('B05X', 252), ('A02B', 225), ('B01A', 219), ('N02B', 213), ('A06A', 204), ('N02A', 195), ('V06D', 189), ('A12C', 171), ('V04C', 168)]\n",
+      "\t- drugs_text:\n",
+      "\t\t- Number of drugs_text per sample: 37.7586\n",
+      "\t\t- Number of unique drugs_text: 700\n",
+      "\t\t- Distribution of drugs_text (Top-10): [('', 864), ('Magnesium Sulfate MAIN 2 g / 50 mL Premix Bag IV 0.0', 294), ('Neutra-Phos MAIN Powder Packet PO 0.0', 273), ('Bisacodyl MAIN 5 mg Tab PO 0.0', 180), ('Furosemide MAIN 40mg/4mL Vial IV 0.0', 168), ('Bisacodyl MAIN 10mg Suppository PR 0.0', 162), ('Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0', 117), ('Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0', 111), ('0.9% Sodium Chloride BASE 1000mL Bag IV 0.0', 111), ('D5W BASE 250mL Bag IV DRIP 0.0', 105)]\n",
+      "\t- conditions_pad:\n",
+      "\t\t- Number of conditions_pad per sample: 1.0000\n",
+      "\t\t- Number of unique conditions_pad: 26\n",
+      "\t\t- Distribution of conditions_pad (Top-10): [(8, 51), (13, 21), (10, 18), (14, 18), (7, 15), (22, 12), (9, 12), (17, 12), (39, 9), (4, 9)]\n",
+      "\t- procedures_pad:\n",
+      "\t\t- Number of procedures_pad per sample: 1.0000\n",
+      "\t\t- Number of unique procedures_pad: 18\n",
+      "\t\t- Distribution of procedures_pad (Top-10): [(0, 48), (2, 42), (1, 36), (3, 24), (5, 21), (7, 18), (6, 12), (4, 12), (9, 9), (10, 9)]\n",
+      "\t- conditions_text_pad:\n",
+      "\t\t- Number of conditions_text_pad per sample: 1.0000\n",
+      "\t\t- Number of unique conditions_text_pad: 26\n",
+      "\t\t- Distribution of conditions_text_pad (Top-10): [(8, 51), (13, 21), (10, 18), (14, 18), (7, 15), (22, 12), (9, 12), (17, 12), (39, 9), (4, 9)]\n",
+      "\t- procedures_text_pad:\n",
+      "\t\t- Number of procedures_text_pad per sample: 1.0000\n",
+      "\t\t- Number of unique procedures_text_pad: 18\n",
+      "\t\t- Distribution of procedures_text_pad (Top-10): [(0, 48), (2, 42), (1, 36), (3, 24), (5, 21), (7, 18), (6, 12), (4, 12), (9, 9), (10, 9)]\n",
+      "\t- drugs_pad:\n",
+      "\t\t- Number of drugs_pad per sample: 1.0000\n",
+      "\t\t- Number of unique drugs_pad: 31\n",
+      "\t\t- Distribution of drugs_pad (Top-10): [(39, 57), (17, 18), (-1, 15), (22, 15), (14, 15), (19, 12), (25, 12), (24, 9), (12, 9), (36, 9)]\n",
+      "\t- drugs_text_pad:\n",
+      "\t\t- Number of drugs_text_pad per sample: 1.0000\n",
+      "\t\t- Number of unique drugs_text_pad: 16\n",
+      "\t\t- Distribution of drugs_text_pad (Top-10): [(39, 189), (-1, 15), (24, 9), (31, 6), (37, 6), (34, 6), (27, 3), (32, 3), (36, 3), (38, 3)]\n",
+      "\t- label:\n",
+      "\t\t- Number of label per sample: 1.0000\n",
+      "\t\t- Number of unique label: 2\n",
+      "\t\t- Distribution of label (Top-10): [(0, 150), (1, 111)]\n"
      ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'patient_id': '10006',\n",
+       " 'visit_id': '142345',\n",
+       " 'num_visits': 1,\n",
+       " 'conditions': ['99591',\n",
+       "  '99662',\n",
+       "  '5672',\n",
+       "  '40391',\n",
+       "  '42731',\n",
+       "  '4280',\n",
+       "  '4241',\n",
+       "  '4240',\n",
+       "  '2874',\n",
+       "  '03819',\n",
+       "  '7850',\n",
+       "  'E8791',\n",
+       "  'V090',\n",
+       "  '56211',\n",
+       "  '28529',\n",
+       "  '25000',\n",
+       "  'V5867',\n",
+       "  'E9342',\n",
+       "  '41401',\n",
+       "  '2749',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0'],\n",
+       " 'conditions_text': ['Sepsis',\n",
+       "  'Infection and inflammatory reaction due to other vascular device, implant, and graft',\n",
+       "  '',\n",
+       "  'Hypertensive chronic kidney disease, unspecified, with chronic kidney disease stage V or end stage renal disease',\n",
+       "  'Atrial fibrillation',\n",
+       "  'Congestive heart failure, unspecified',\n",
+       "  'Aortic valve disorders',\n",
+       "  'Mitral valve disorders',\n",
+       "  '',\n",
+       "  'Other staphylococcal septicemia',\n",
+       "  'Tachycardia, unspecified',\n",
+       "  'Kidney dialysis as the cause of abnormal reaction of patient, or of later complication, without mention of misadventure at time of procedure',\n",
+       "  'Infection with microorganisms resistant to penicillins',\n",
+       "  'Diverticulitis of colon (without mention of hemorrhage)',\n",
+       "  'Anemia of other chronic disease',\n",
+       "  'Diabetes mellitus without mention of complication, type II or unspecified type, not stated as uncontrolled',\n",
+       "  'Long-term (current) use of insulin',\n",
+       "  'Anticoagulants causing adverse effects in therapeutic use',\n",
+       "  'Coronary atherosclerosis of native coronary artery',\n",
+       "  'Gout, unspecified',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  ''],\n",
+       " 'procedures': ['9749',\n",
+       "  '5491',\n",
+       "  '3895',\n",
+       "  '3995',\n",
+       "  '3893',\n",
+       "  '9907',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0'],\n",
+       " 'procedures_text': ['Removal of other device from thorax',\n",
+       "  'Percutaneous abdominal drainage',\n",
+       "  'Venous catheterization for renal dialysis',\n",
+       "  'Hemodialysis',\n",
+       "  'Venous catheterization, not elsewhere classified',\n",
+       "  'Transfusion of other serum',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  '',\n",
+       "  ''],\n",
+       " 'drugs': ['A10B',\n",
+       "  'J01C',\n",
+       "  'B05X',\n",
+       "  'J01D',\n",
+       "  'J01X',\n",
+       "  'V06D',\n",
+       "  'V03A',\n",
+       "  'C07A',\n",
+       "  'A02B',\n",
+       "  'N02B',\n",
+       "  'A02A',\n",
+       "  'A06A',\n",
+       "  'A12C',\n",
+       "  'V04C',\n",
+       "  'A12B',\n",
+       "  'R06A',\n",
+       "  'N05C',\n",
+       "  'B02B',\n",
+       "  'P01A',\n",
+       "  'J01M',\n",
+       "  'N02A',\n",
+       "  'B01A',\n",
+       "  'C09A',\n",
+       "  'C10A',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0',\n",
+       "  '0'],\n",
+       " 'drugs_text': ['Glipizide MAIN 10MG TAB PO 0.0',\n",
+       "  'Ampicillin Sodium MAIN 2g Vial IV 0.0',\n",
+       "  'NS BASE 50mL Bag IV 0.0',\n",
+       "  'Ceftriaxone MAIN 1g Frozen Bag IV 0.0',\n",
+       "  'Metronidazole MAIN 500mg Premix Bag IV 0.0',\n",
+       "  'Vancomycin HCl MAIN 10g Vial IV 0.0',\n",
+       "  'D5W BASE 250mL Bag IV 0.0',\n",
+       "  'Unasyn MAIN 3g Vial IV 0.0',\n",
+       "  'NS BASE 100mL Bag IV 0.0',\n",
+       "  'Sevelamer MAIN 800mg Tab PO 0.0',\n",
+       "  'Metoprolol MAIN 50mg Tab PO 0.0',\n",
+       "  'Pantoprazole MAIN 40mg Tab PO 0.0',\n",
+       "  'Acetaminophen MAIN 325mg Tab PO 0.0',\n",
+       "  'Magnesium Oxide MAIN 140mg Cap PPK PO 0.0',\n",
+       "  'Magnesium Oxide MAIN 140mg Cap PPK PO 0.0',\n",
+       "  'Magnesium Oxide MAIN 140mg Cap PPK PO 0.0',\n",
+       "  'Magnesium Sulfate MAIN 1gm/2ml vial IV 0.0',\n",
+       "  'Magnesium Sulfate MAIN 1gm/2ml vial IV 0.0',\n",
+       "  'Magnesium Sulfate MAIN 1gm/2ml vial IV 0.0',\n",
+       "  'Docusate Sodium MAIN 100mg Cap PO 0.0',\n",
+       "  'Senna MAIN 1 TAB PO 0.0',\n",
+       "  'Linezolid MAIN 600mg Tab PO 0.0',\n",
+       "  'Linezolid MAIN 600mg Tab PO 0.0',\n",
+       "  'Potassium Chloride MAIN 20mEq Packet PO 0.0',\n",
+       "  'Diphenhydramine HCl MAIN 25mg Cap PO 0.0',\n",
+       "  'Zolpidem Tartrate MAIN 5mg Tab PO 0.0',\n",
+       "  'Vancomycin HCl MAIN 1GM FROZ. BAG IV 0.0',\n",
+       "  'Midazolam HCl MAIN 2mg/2mL Vial IV 0.0',\n",
+       "  'Midazolam HCl MAIN 2mg/2mL Vial IV 0.0',\n",
+       "  'Phytonadione MAIN 10mg/mL Amp SC 0.0',\n",
+       "  'Vancomycin HCl MAIN 1GM FROZ. BAG IV 0.0',\n",
+       "  'Metronidazole MAIN 500MG TAB PO 0.0',\n",
+       "  'Levofloxacin MAIN 250mg Tab PO 0.0',\n",
+       "  'Oxycodone-Acetaminophen MAIN 5mg/325mg Tab PO 0.0',\n",
+       "  'Alteplase (Catheter Clearance) MAIN 1mg/1mL Syringe IV 0.0',\n",
+       "  'Warfarin MAIN 2mg Tab PO 0.0',\n",
+       "  'Moexipril HCl MAIN 7.5mg Tab PO 0.0',\n",
+       "  'Sevelamer MAIN 400mg Tab PO 0.0',\n",
+       "  'Sevelamer MAIN 800mg Tab PO 0.0',\n",
+       "  ''],\n",
+       " 'conditions_pad': 20,\n",
+       " 'procedures_pad': 6,\n",
+       " 'conditions_text_pad': 20,\n",
+       " 'procedures_text_pad': 6,\n",
+       " 'drugs_pad': 24,\n",
+       " 'drugs_text_pad': 39,\n",
+       " 'label': 0}"
+      ]
+     },
+     "execution_count": 132,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
    "source": [
@@ -1763,12 +2028,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 134,
    "id": "924c8e3c-0482-4688-b2d6-d83d490060a9",
    "metadata": {
     "tags": []
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "ename": "IndentationError",
+     "evalue": "unindent does not match any outer indentation level (<tokenize>, line 184)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;36m  File \u001b[0;32m\"<tokenize>\"\u001b[0;36m, line \u001b[0;32m184\u001b[0m\n\u001b[0;31m    samples = data\u001b[0m\n\u001b[0m    ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unindent does not match any outer indentation level\n"
+     ]
+    }
+   ],
    "source": [
     "# def per_visit_collate_fn(data):\n",
     "#     \"\"\"\n",
@@ -1872,8 +2146,11 @@
     "#     return x, masks, rev_x, rev_masks, y\n",
     "\n",
     "\n",
-    "def per_visit_collate_fn(data):\n",
+    "def bert_per_patient_collate_function(data):\n",
     "    \"\"\"\n",
+    "    Collates a tensor of (batch_size, seq_len, bert_emb_len) i.e. (32, <variable_per_patient>, 768)\n",
+    "    Need to pad the second dimension to max(<variable_per_patient>)\n",
+    "    \n",
     "    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis\n",
     "        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation\n",
     "        is stored in `mask`.\n",
@@ -1882,11 +2159,11 @@
     "        data: a list of samples fetched from `CustomDataset`\n",
     "        \n",
     "    Outputs:\n",
-    "        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long\n",
-    "        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool\n",
+    "        x: a tensor of shape (batch_size, max #conditions, max embedding size) of type torch.float\n",
+    "        masks: a tensor of shape (batch_size, max #conditions, max embedding_size) of type torch.bool\n",
     "        rev_x: same as x but in reversed time. This will be used in our RNN model for masking \n",
     "        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking\n",
-    "        y: a tensor of shape (# patiens) of type torch.float\n",
+    "        y: a tensor of shape (batch_size) of type torch.float\n",
     "        \n",
     "    Note that you can obtains the list of diagnosis codes and the list of hf labels\n",
     "        using: `sequences, labels = zip(*data)`\n",
@@ -1927,14 +2204,66 @@
     "    \n",
     "    return x, masks, rev_x, rev_masks, y\n",
     "\n",
-    "def event_seq_collate_fn(data):\n",
-    "    # TBD(botelho3)\n",
-    "    pass"
+    "def code_emb_per_patient_collate_function(data):\n",
+    "    \"\"\"\n",
+    "    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis\n",
+    "    sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation\n",
+    "    is stored in `mask`.\n",
+    "                                   \n",
+    "    \n",
+    "    Arguments:\n",
+    "        data: a list of samples fetched from `CustomDataset`\n",
+    "        \n",
+    "    Outputs:\n",
+    "    x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long\n",
+    "    masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool\n",
+    "    rev_x: same as x but in reversed time. This will be used in our RNN model for masking\n",
+    "    rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking\n",
+    "    y: a tensor of shape (# patiens) of type torch.float\n",
+    "        \n",
+    "    Note that you can obtains the list of diagnosis codes and the list of hf labels\\n\",\n",
+    "    using: `sequences, labels = zip(*data)`\\n\",\n",
+    "    \"\"\"\n",
+    "    \n",
+    "    # sequences, labels = zip(*data)\n",
+    "    samples = data\n",
+    "\n",
+    "    y = torch.tensor(labels, dtype=torch.float)\n",
+    "    \n",
+    "    num_patients = len(sequences)\n",
+    "    num_visits = [len(patient) for patient in sequences]\n",
+    "    num_codes = [len(visit) for patient in sequences for visit in patient]\n",
+    "\n",
+    "    max_num_visits = max(num_visits)\n",
+    "    max_num_codes = max(num_codes)\n",
+    "    \n",
+    "    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+    "    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+    "    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+    "    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+    "    for i_patient, patient in enumerate(sequences):\n",
+    "        kMaxVisits = len(patient)\n",
+    "        for j_visit, visit in enumerate(patient):\n",
+    "            \"\"\"\n",
+    "            TODO: update `x`, `rev_x`, `masks`, and `rev_masks`\n",
+    "            \"\"\"\n",
+    "            l = len(visit)\n",
+    "            x[i_patient, j_visit, 0:l] = torch.LongTensor(visit)\n",
+    "            rev_x[i_patient, kMaxVisits-1-j_visit, 0:l] = torch.LongTensor(visit)\n",
+    "            masks[i_patient, j_visit, 0:l] = 1\n",
+    "            rev_masks[i_patient, kMaxVisits-1-j_visit, 0:l] = 1\n",
+    "        # print(f\"------ v: {kMaxVisits} ------\")\n",
+    "        # print(x[i_patient, :, ])\n",
+    "        # print(rev_x[i_patient, :, ])\n",
+    "        # print(masks[i_patient, :, ])\n",
+    "        # print(rev_masks[i_patient, :, ])\n",
+    "    \n",
+    "    return x, masks, rev_x, rev_masks, y"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 110,
    "id": "f6afd1f5-0f3b-4f4b-9c85-f74bb367ef88",
    "metadata": {},
    "outputs": [],
@@ -1967,6 +2296,10 @@
     "        tokenized_text = self.tokenizer.tokenize(text)\n",
     "        return tokenized_text\n",
     "    \n",
+    "    def _get_embeddings_of_sentences_with_mask(sample, field, pad):\n",
+    "        return self._get_embeddings_of_sentences(field[:pad])\n",
+    "        \n",
+    "    \n",
     "    def _get_embeddings_of_sentences(self, sentences: List[str]) -> torch.tensor:\n",
     "        # tokenized_sentences = [self.tokenizer.tokenize(t, padding=True) for t in sentences]\n",
     "        # print(batch_enc)\n",
@@ -1991,18 +2324,30 @@
     "        condition_embeddings = [None] * len(sample['conditions_text'])\n",
     "        procedure_embeddings = [None] * len(sample['procedures_text'])\n",
     "        drug_embeddings = [None] * len(sample['drugs_text'])\n",
-    "        \n",
-    "        condition_embeddings, condition_attentions = self._get_embeddings_of_sentences(sample['conditions_text'])\n",
+    "       \n",
+    "        if sample.get('conditions_text_pad'):\n",
+    "            condition_embeddings, condition_attentions = self._get_embeddings_of_sentences_with_mask(\n",
+    "                sample['conditions_text'], sample['conditions_text_pad'])\n",
+    "        else:\n",
+    "            condition_embeddings, condition_attentions = self._get_embeddings_of_sentences(sample['conditions_text'])\n",
     "        # for cond in sample['conditions_text']:\n",
     "        #     #bert_model(cond)\n",
     "        #     condition_embeddings.append(cond)\n",
     "        #     pass\n",
-    "        procedure_embeddings, procedure_attentions = self._get_embeddings_of_sentences(sample['procedures_text'])\n",
+    "        if sample.get('procedures_text_pad'):\n",
+    "            procedure_embeddings, procedure_attentions = self._get_embeddings_of_sentences_with_mask(\n",
+    "                sample['conditions_text'])(sample['procedures_text'], sample['procedures_text_pad'])\n",
+    "        else:\n",
+    "            procedure_embeddings, procedure_attentions = self._get_embeddings_of_sentences(sample['procedures_text'])\n",
     "        # for proc in sample['procedures_text']:\n",
     "        #     #bert_model(proc)\n",
     "        #     procedure_embeddings.append(proc)\n",
     "        #     pass\n",
-    "        drug_embeddings, drug_attentions = self._get_embeddings_of_sentences(sample['drugs_text'])\n",
+    "        if (sample.get('drugs_text_pad')):\n",
+    "            drug_embeddings, drug_attentions = self._get_embeddings_of_sentences_with_mask(\n",
+    "                sample['drugs_text'], sample['drugs_text_pad'])\n",
+    "        else:\n",
+    "            drug_embeddings, drug_attentions = self._get_embeddings_of_sentences(sample['drugs_text'])\n",
     "        # for drug in sample['drugs_text']:\n",
     "        #     #bert_model(drug)\n",
     "        #     procedure_embeddings.append(drug)\n",
@@ -2016,13 +2361,18 @@
     "              f'de: {drug_embeddings.shape}')\n",
     "       \n",
     "        stacked = torch.cat([condition_embeddings, procedure_embeddings, drug_embeddings], dim=0)\n",
-    "        summed_embeddings = torch.sum(stacked, dim=0, keepdim=True)  # sum across rows\n",
-    "        # normalize the final row (across columns)\n",
-    "        summed_embeddings = torch.nn.functional.normalize(summed_embeddings, dim=1)\n",
     "        \n",
-    "        print(summed_embeddings.shape)\n",
-    "        assert(summed_embeddings.shape == (1, self.bert_config.hidden_size))\n",
-    "        return summed_embeddings\n",
+    "        # We don't want to normalize here because we need a sequence of embeddings for each sample. \n",
+    "        # normalize the final row (across columns)\n",
+    "        # summed_embeddings = torch.sum(stacked, dim=0, keepdim=True)  # sum across rows\n",
+    "        # summed_embeddings = torch.nn.functional.normalize(summed_embeddings, dim=1)\n",
+    "        # print(summed_embeddings.shape)\n",
+    "        # assert(summed_embeddings.shape == (1, self.bert_config.hidden_size))\n",
+    "       \n",
+    "        # The 1st dimension is seq length. The second dimension is embedding length of each sentence.\n",
+    "        assert(stacked.shape[-1] == self.bert_config.hidden_size)\n",
+    "        assert(len(stacked.shape) == 2)\n",
+    "        return stacked\n",
     "            \n",
     "#         image, landmarks = sample['image'], sample['landmarks']\n",
     "\n",
@@ -2076,7 +2426,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 111,
    "id": "1056b60a-5861-4e47-b694-891053cc8470",
    "metadata": {},
    "outputs": [
@@ -2087,7 +2437,7 @@
       "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n",
       "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
-      "Generating samples for mortality_pred_task: 100%|███████████████████| 100/100 [00:00<00:00, 11527.25it/s]\n"
+      "Generating samples for mortality_pred_task: 100%|███████████████████| 100/100 [00:00<00:00, 14394.62it/s]\n"
      ]
     }
    ],
@@ -2156,10 +2506,55 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 112,
    "id": "035fccee-d5da-428c-a54a-0678a2d19b5e",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ce: torch.Size([22, 768])pe: torch.Size([1, 768])de: torch.Size([36, 768])\n",
+      "ce: torch.Size([15, 768])pe: torch.Size([8, 768])de: torch.Size([422, 768])\n",
+      "ce: torch.Size([17, 768])pe: torch.Size([1, 768])de: torch.Size([47, 768])\n",
+      "ce: torch.Size([9, 768])pe: torch.Size([1, 768])de: torch.Size([62, 768])\n",
+      "ce: torch.Size([37, 768])pe: torch.Size([3, 768])de: torch.Size([512, 768])\n",
+      "ce: torch.Size([8, 768])pe: torch.Size([4, 768])de: torch.Size([354, 768])\n",
+      "ce: torch.Size([7, 768])pe: torch.Size([1, 768])de: torch.Size([337, 768])\n",
+      "ce: torch.Size([17, 768])pe: torch.Size([4, 768])de: torch.Size([74, 768])\n",
+      "ce: torch.Size([9, 768])pe: torch.Size([5, 768])de: torch.Size([607, 768])\n",
+      "ce: torch.Size([11, 768])pe: torch.Size([2, 768])de: torch.Size([483, 768])\n",
+      "ce: torch.Size([14, 768])pe: torch.Size([1, 768])de: torch.Size([199, 768])\n",
+      "ce: torch.Size([24, 768])pe: torch.Size([2, 768])de: torch.Size([62, 768])\n",
+      "ce: torch.Size([18, 768])pe: torch.Size([3, 768])de: torch.Size([182, 768])\n",
+      "ce: torch.Size([13, 768])pe: torch.Size([7, 768])de: torch.Size([679, 768])\n",
+      "ce: torch.Size([11, 768])pe: torch.Size([9, 768])de: torch.Size([52, 768])\n",
+      "ce: torch.Size([12, 768])pe: torch.Size([5, 768])de: torch.Size([63, 768])\n",
+      "ce: torch.Size([19, 768])pe: torch.Size([1, 768])de: torch.Size([9, 768])\n",
+      "ce: torch.Size([8, 768])pe: torch.Size([1, 768])de: torch.Size([67, 768])\n",
+      "ce: torch.Size([5, 768])pe: torch.Size([3, 768])de: torch.Size([21, 768])\n",
+      "ce: torch.Size([21, 768])pe: torch.Size([1, 768])de: torch.Size([71, 768])\n",
+      "ce: torch.Size([4, 768])pe: torch.Size([4, 768])de: torch.Size([32, 768])\n",
+      "ce: torch.Size([18, 768])pe: torch.Size([6, 768])de: torch.Size([221, 768])\n",
+      "ce: torch.Size([21, 768])pe: torch.Size([2, 768])de: torch.Size([307, 768])\n"
+     ]
+    },
+    {
+     "ename": "NameError",
+     "evalue": "name 'labels' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_66272/3599354066.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;31m# loader = DataLoader(mort_, batch_size=10, collate_fn=collate_fn)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mloader_iter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmort_train_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloader_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/envs/cs589_dl_healthcare_assignments/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    679\u001b[0m                 \u001b[0;31m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    680\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 681\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    682\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    683\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/envs/cs589_dl_healthcare_assignments/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    719\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    720\u001b[0m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 721\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    722\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    723\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/envs/cs589_dl_healthcare_assignments/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     50\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_66272/2662870878.py\u001b[0m in \u001b[0;36mper_visit_collate_fn\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m    124\u001b[0m     \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m     \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    128\u001b[0m     \u001b[0mnum_patients\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msequences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'labels' is not defined"
+     ]
+    }
+   ],
    "source": [
     "# Verify DataLoader properties.\n",
     "\n",