diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb
index 1fd42899e3d4388ec276c846b0222b943732f323..a6dbd06a4f67e938fc102d339773ff801faccbd6 100644
--- a/src/cs598_dlh_final_project.ipynb
+++ b/src/cs598_dlh_final_project.ipynb
@@ -2207,7 +2263,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 110,
+   "execution_count": 141,
    "id": "f6afd1f5-0f3b-4f4b-9c85-f74bb367ef88",
    "metadata": {},
    "outputs": [],
@@ -2370,7 +2426,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 138,
+   "execution_count": 140,
    "id": "1056b60a-5861-4e47-b694-891053cc8470",
    "metadata": {},
    "outputs": [
@@ -2381,7 +2437,7 @@
       "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n",
       "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
-      "Generating samples for mortality_pred_task: 100%|█████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 14196.80it/s]\n"
+      "Generating samples for mortality_pred_task: 100%|█████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 12923.05it/s]\n"
      ]
     }
    ],
@@ -2528,6 +2584,73 @@
     "#### CodeEmb Loader"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "992467f7-d834-4d4a-872a-714389fa203c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "def code_emb_per_patient_collate_function(data):\n",
+    "    \"\"\"\n",
+    "    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis\n",
+    "    sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation\n",
+    "    is stored in `mask`.\n",
+    "                                   \n",
+    "    \n",
+    "    Arguments:\n",
+    "        data: a list of samples fetched from `CustomDataset`\n",
+    "        \n",
+    "    Outputs:\n",
+    "    x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long\n",
+    "    masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool\n",
+    "    rev_x: same as x but in reversed time. This will be used in our RNN model for masking\n",
+    "    rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking\n",
+    "    y: a tensor of shape (# patiens) of type torch.float\n",
+    "        \n",
+    "    Note that you can obtains the list of diagnosis codes and the list of hf labels\\n\",\n",
+    "    using: `sequences, labels = zip(*data)`\\n\",\n",
+    "    \"\"\"\n",
+    "    \n",
+    "    # sequences, labels = zip(*data)\n",
+    "    samples = data\n",
+    "    \n",
+    "    y = torch.tensor(labels, dtype=torch.float)\n",
+    "    \n",
+    "    num_patients = len(sequences)\n",
+    "    num_visits = [len(patient) for patient in sequences]\n",
+    "    num_codes = [len(visit) for patient in sequences for visit in patient]\n",
+    "\n",
+    "    max_num_visits = max(num_visits)\n",
+    "    max_num_codes = max(num_codes)\n",
+    "    \n",
+    "    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+    "    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+    "    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+    "    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+    "    for i_patient, patient in enumerate(sequences):\n",
+    "        kMaxVisits = len(patient)\n",
+    "        for j_visit, visit in enumerate(patient):\n",
+    "            \"\"\"\n",
+    "            TODO: update `x`, `rev_x`, `masks`, and `rev_masks`\n",
+    "            \"\"\"\n",
+    "            l = len(visit)\n",
+    "            x[i_patient, j_visit, 0:l] = torch.LongTensor(visit)\n",
+    "            rev_x[i_patient, kMaxVisits-1-j_visit, 0:l] = torch.LongTensor(visit)\n",
+    "            masks[i_patient, j_visit, 0:l] = 1\n",
+    "            rev_masks[i_patient, kMaxVisits-1-j_visit, 0:l] = 1\n",
+    "        # print(f\"------ v: {kMaxVisits} ------\")\n",
+    "        # print(x[i_patient, :, ])\n",
+    "        # print(rev_x[i_patient, :, ])\n",
+    "        # print(masks[i_patient, :, ])\n",
+    "        # print(rev_masks[i_patient, :, ])\n",
+    "    \n",
+    "    return x, masks, rev_x, rev_masks, y"
+   ]
+  },
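+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "collate-fn-sanity-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check for code_emb_per_patient_collate_function -- a minimal sketch\n",
+    "# on a hand-built toy batch. The (sequence, label) tuple format is assumed\n",
+    "# from the docstring above; nothing here touches the real dataset.\n",
+    "_toy_batch = [\n",
+    "    ([[1, 2], [3]], 0),    # patient with 2 visits\n",
+    "    ([[4, 5, 6]], 1),      # patient with 1 visit\n",
+    "]\n",
+    "_x, _masks, _rev_x, _rev_masks, _y = code_emb_per_patient_collate_function(_toy_batch)\n",
+    "assert _x.shape == (2, 2, 3)  # (# patients, max # visits, max # codes)\n",
+    "assert _x[0, 0, :2].tolist() == [1, 2]        # forward time\n",
+    "assert _rev_x[0, 1, :2].tolist() == [1, 2]    # reversed time\n",
+    "assert _masks.sum() == _rev_masks.sum() == 6  # one True per real code\n",
+    "assert _y.tolist() == [0.0, 1.0]"
+   ]
+  },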
   {
    "cell_type": "code",
    "execution_count": null,
@@ -2535,10 +2658,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Create the transform that will take each sample (visit) in the dataset\n",
-    "# and convert the text description of the visit into a single embedding.\n",
-    "bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_)\n",
-    "\n",
     "# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)\n",
     "# Add a transform to convert the visit into an embedding.\n",
     "mortality_cemb_ds = mimic3base.set_task(task_fn=mortality_pred_task)\n",
@@ -2553,65 +2672,22 @@
     "    mort_cemb_train_ds,\n",
     "    batch_size=BATCH_SIZE_,\n",
     "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
+    "    collate_fn=code_emb_per_patient_collate_function,\n",
     ")\n",
     "mort_cemb_val_loader = DataLoader(\n",
     "    mort_cemb_val_ds,\n",
     "    batch_size=BATCH_SIZE_,\n",
     "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
+    "    collate_fn=code_emb_per_patient_collate_function,\n",
     ")\n",
     "mort_cemb_test_loader = DataLoader(\n",
     "    mort_cemb_test_ds,\n",
     "    batch_size=BATCH_SIZE_,\n",
     "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
+    "    collate_fn=code_emb_per_patient_collate_function,\n",
     ")"
    ]
   },
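+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cemb-loader-smoke-test",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Smoke test -- a minimal sketch, assuming the loaders above were built and\n",
+    "# the split is non-empty: pull one batch from mort_cemb_train_loader and\n",
+    "# confirm the collated tensors are mutually consistent before training.\n",
+    "_x, _masks, _rev_x, _rev_masks, _y = next(iter(mort_cemb_train_loader))\n",
+    "print(f\"x: {tuple(_x.shape)}  y: {tuple(_y.shape)}\")\n",
+    "assert _x.shape == _masks.shape == _rev_x.shape == _rev_masks.shape\n",
+    "assert _y.shape[0] == _x.shape[0]        # one label per patient in the batch\n",
+    "assert _masks.sum() == _rev_masks.sum()  # reversal preserves the code count"
+   ]
+  },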
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "9436cb7d-e302-460c-b35d-3025dff4b999",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Create the transform that will take each sample (visit) in the dataset\n",
-    "# and convert the text description of the visit into a single embedding.\n",
-    "bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_)\n",
-    "\n",
-    "# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)\n",
-    "# Add a transform to convert the visit into an embedding.\n",
-    "mortality_dataset = mimic3base.set_task(task_fn=mortality_pred_task)\n",
-    "mortality_dataset = TextEmbedDataset(mortality_dataset, transform=bert_xform)\n",
-    "mort_train_ds, mort_val_ds, mort_test_ds = split_by_patient(mortality_dataset, [0.8, 0.1, 0.1])\n",
-    "\n",
-    "# create dataloaders (torch.data.DataLoader)\n",
-    "# mort_train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True, collate_fn)\n",
-    "# mort_val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)\n",
-    "# mort_test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n",
-    "\n",
-    "mort_train_loader = DataLoader(\n",
-    "    mort_train_ds,\n",
-    "    batch_size=BATCH_SIZE_,\n",
-    "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
-    ")\n",
-    "mort_val_loader = DataLoader(\n",
-    "    mort_val_ds,\n",
-    "    batch_size=BATCH_SIZE_,\n",
-    "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
-    ")\n",
-    "mort_test_loader = DataLoader(\n",
-    "    mort_test_ds,\n",
-    "    batch_size=BATCH_SIZE_,\n",
-    "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
-    ")\n",
-    "\n"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "ebe2c2d9-c161-4b9d-8182-86384faea929",
@@ -2772,6 +2848,14 @@
    "source": [
     "## Train DescEmb"
   ]
-  }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fc9335e3-0da2-47f3-9a90-53531ad277b0",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {