diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb
index 1fd42899e3d4388ec276c846b0222b943732f323..a6dbd06a4f67e938fc102d339773ff801faccbd6 100644
--- a/src/cs598_dlh_final_project.ipynb
+++ b/src/cs598_dlh_final_project.ipynb
@@ -2193,7 +2193,63 @@
     "    # print(masks[i_patient, :, ])\n",
     "    # print(rev_masks[i_patient, :, ])\n",
     "    \n",
-    "    return x, masks, rev_x, rev_masks, y\n"
+    "    return x, masks, rev_x, rev_masks, y\n",
+    "\n",
+    "def code_emb_per_patient_collate_function(data):\n",
+    "    \"\"\"\n",
+    "    Collate the list of samples into batches. For each patient, pad the diagnosis\n",
+    "    sequences to the same shape (max # visits, max # diagnosis codes). The padding\n",
+    "    information is stored in `masks`.\n",
+    "\n",
+    "    Arguments:\n",
+    "        data: a list of samples fetched from `CustomDataset`\n",
+    "\n",
+    "    Outputs:\n",
+    "        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long\n",
+    "        masks: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.bool\n",
+    "        rev_x: same as x but in reversed time; used in our RNN model for masking\n",
+    "        rev_masks: same as masks but in reversed time; used in our RNN model for masking\n",
+    "        y: a tensor of shape (# patients) of type torch.float\n",
+    "\n",
+    "    Note that you can obtain the list of diagnosis codes and the list of labels\n",
+    "    using: `sequences, labels = zip(*data)`\n",
+    "    \"\"\"\n",
+    "    sequences, labels = zip(*data)\n",
+    "\n",
+    "    y = torch.tensor(labels, dtype=torch.float)\n",
+    "\n",
+    "    num_patients = len(sequences)\n",
+    "    num_visits = [len(patient) for patient in sequences]\n",
+    "    num_codes = [len(visit) for patient in sequences for visit in patient]\n",
+    "\n",
+    "    max_num_visits = max(num_visits)\n",
+    "    max_num_codes = max(num_codes)\n",
+    "\n",
+    "    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+    "    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+    "    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+    "    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+    "    for i_patient, patient in enumerate(sequences):\n",
+    "        # This patient's visit count; visit j is placed at slot\n",
+    "        # (n_visits - 1 - j) in the reversed-time tensors.\n",
+    "        n_visits = len(patient)\n",
+    "        for j_visit, visit in enumerate(patient):\n",
+    "            # Copy the visit's codes and flag the filled positions in the masks.\n",
+    "            l = len(visit)\n",
+    "            x[i_patient, j_visit, 0:l] = torch.LongTensor(visit)\n",
+    "            rev_x[i_patient, n_visits - 1 - j_visit, 0:l] = torch.LongTensor(visit)\n",
+    "            masks[i_patient, j_visit, 0:l] = 1\n",
+    "            rev_masks[i_patient, n_visits - 1 - j_visit, 0:l] = 1\n",
+    "\n",
+    "    return x, masks, rev_x, rev_masks, y"
    ]
   },
   {
@@ -2207,7 +2263,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 110,
+   "execution_count": 141,
    "id": "f6afd1f5-0f3b-4f4b-9c85-f74bb367ef88",
    "metadata": {},
    "outputs": [],
@@ -2370,7 +2426,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 138,
+   "execution_count": 140,
    "id": "1056b60a-5861-4e47-b694-891053cc8470",
    "metadata": {},
    "outputs": [
@@ -2381,7 +2437,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
       "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n",
       "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
-       "Generating samples for mortality_pred_task: 100%|█████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 14196.80it/s]\n"
+       "Generating samples for mortality_pred_task: 100%|█████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 12923.05it/s]\n"
      ]
     }
    ],
    "source": [
@@ -2528,6 +2584,73 @@
    "#### CodeEmb Loader"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "992467f7-d834-4d4a-872a-714389fa203c",
+  "metadata": {
+   "tags": []
+  },
+  "outputs": [],
+  "source": [
+   "def code_emb_per_patient_collate_function(data):\n",
+   "    \"\"\"\n",
+   "    Collate the list of samples into batches. For each patient, pad the diagnosis\n",
+   "    sequences to the same shape (max # visits, max # diagnosis codes). The padding\n",
+   "    information is stored in `masks`.\n",
+   "\n",
+   "    Arguments:\n",
+   "        data: a list of samples fetched from `CustomDataset`\n",
+   "\n",
+   "    Outputs:\n",
+   "        x: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.long\n",
+   "        masks: a tensor of shape (# patients, max # visits, max # diagnosis codes) of type torch.bool\n",
+   "        rev_x: same as x but in reversed time; used in our RNN model for masking\n",
+   "        rev_masks: same as masks but in reversed time; used in our RNN model for masking\n",
+   "        y: a tensor of shape (# patients) of type torch.float\n",
+   "\n",
+   "    Note that you can obtain the list of diagnosis codes and the list of labels\n",
+   "    using: `sequences, labels = zip(*data)`\n",
+   "    \"\"\"\n",
+   "    sequences, labels = zip(*data)\n",
+   "\n",
+   "    y = torch.tensor(labels, dtype=torch.float)\n",
+   "\n",
+   "    num_patients = len(sequences)\n",
+   "    num_visits = [len(patient) for patient in sequences]\n",
+   "    num_codes = [len(visit) for patient in sequences for visit in patient]\n",
+   "\n",
+   "    max_num_visits = max(num_visits)\n",
+   "    max_num_codes = max(num_codes)\n",
+   "\n",
+   "    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+   "    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n",
+   "    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+   "    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n",
+   "    for i_patient, patient in enumerate(sequences):\n",
+   "        # This patient's visit count; visit j is placed at slot\n",
+   "        # (n_visits - 1 - j) in the reversed-time tensors.\n",
+   "        n_visits = len(patient)\n",
+   "        for j_visit, visit in enumerate(patient):\n",
+   "            # Copy the visit's codes and flag the filled positions in the masks.\n",
+   "            l = len(visit)\n",
+   "            x[i_patient, j_visit, 0:l] = torch.LongTensor(visit)\n",
+   "            rev_x[i_patient, n_visits - 1 - j_visit, 0:l] = torch.LongTensor(visit)\n",
+   "            masks[i_patient, j_visit, 0:l] = 1\n",
+   "            rev_masks[i_patient, n_visits - 1 - j_visit, 0:l] = 1\n",
+   "\n",
+   "    return x, masks, rev_x, rev_masks, y"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -2535,10 +2658,6 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "# Create the transform that will take each sample (visit) in the dataset\n",
-    "# and convert the text description of the visit into a single embedding.\n",
-    "bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_)\n",
-    "\n",
     "# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)\n",
     "# Add a transform to convert the visit into an embedding.\n",
     "mortality_cemb_ds = mimic3base.set_task(task_fn=mortality_pred_task)\n",
@@ -2553,65 +2672,22 @@
     "    mort_cemb_train_ds,\n",
     "    batch_size=BATCH_SIZE_,\n",
     "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
+    "    collate_fn=code_emb_per_patient_collate_function,\n",
     ")\n",
     "mort_cemb_val_loader = DataLoader(\n",
     "    mort_cemb_val_ds,\n",
     "    batch_size=BATCH_SIZE_,\n",
     "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
+    "    collate_fn=code_emb_per_patient_collate_function,\n",
     ")\n",
     "mort_cemb_test_loader = DataLoader(\n",
     "    mort_cemb_test_ds,\n",
     "    batch_size=BATCH_SIZE_,\n",
     "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
+    "    collate_fn=code_emb_per_patient_collate_function,\n",
     ")"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "9436cb7d-e302-460c-b35d-3025dff4b999",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Create the transform that will take each sample (visit) in the dataset\n",
-    "# and convert the text description of the visit into a single embedding.\n",
-    "bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_)\n",
-    "\n",
-    "# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)\n",
-    "# Add a transform to convert the visit into an embedding.\n",
-    "mortality_dataset = mimic3base.set_task(task_fn=mortality_pred_task)\n",
-    "mortality_dataset = TextEmbedDataset(mortality_dataset, transform=bert_xform)\n",
-    "mort_train_ds, mort_val_ds, mort_test_ds = split_by_patient(mortality_dataset, [0.8, 0.1, 0.1])\n",
-    "\n",
-    "# create dataloaders (torch.data.DataLoader)\n",
-    "# mort_train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True, collate_fn)\n",
-    "# mort_val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)\n",
-    "# mort_test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n",
-    "\n",
-    "mort_train_loader = DataLoader(\n",
-    "    mort_train_ds,\n",
-    "    batch_size=BATCH_SIZE_,\n",
-    "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
-    ")\n",
-    "mort_val_loader = DataLoader(\n",
-    "    mort_val_ds,\n",
-    "    batch_size=BATCH_SIZE_,\n",
-    "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
-    ")\n",
-    "mort_test_loader = DataLoader(\n",
-    "    mort_test_ds,\n",
-    "    batch_size=BATCH_SIZE_,\n",
-    "    shuffle=SHUFFLE_,\n",
-    "    collate_fn=bert_per_patient_collate_function,\n",
-    ")\n",
-    "\n"
-   ]
-  },
  {
   "cell_type": "markdown",
   "id": "ebe2c2d9-c161-4b9d-8182-86384faea929",
@@ -2772,6 +2848,14 @@
   "source": [
    "## Train DescEmb"
   ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "fc9335e3-0da2-47f3-9a90-53531ad277b0",
+  "metadata": {},
+  "outputs": [],
+  "source": []
  }
 ],
 "metadata": {
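
A quick sanity check for the collate function added above. This is a minimal sketch, not part of the diff: it assumes the corrected function body (with `sequences, labels = zip(*data)` restored) is in scope, and the toy visit sequences and labels are made up for illustration.

    import torch

    # Toy batch of (visit_sequences, label) pairs, mimicking what the
    # DataLoader passes to collate_fn. Patient 0 has two visits with
    # 3 and 1 diagnosis codes; patient 1 has one visit with 2 codes.
    data = [
        ([[12, 7, 42], [5]], 1.0),
        ([[3, 9]], 0.0),
    ]

    x, masks, rev_x, rev_masks, y = code_emb_per_patient_collate_function(data)

    # All four tensors are padded to (num_patients, max_visits, max_codes) = (2, 2, 3).
    assert x.shape == masks.shape == rev_x.shape == rev_masks.shape == (2, 2, 3)

    # In reversed time, patient 0's last visit ([5]) comes first.
    assert torch.equal(rev_x[0, 0, :1], torch.tensor([5]))
    assert torch.equal(rev_x[0, 1, :], torch.tensor([12, 7, 42]))

    # The masks flag exactly the non-padding positions: 3 + 1 codes for
    # patient 0, and 2 for patient 1.
    assert masks[0].sum() == 4 and masks[1].sum() == 2
    assert torch.equal(y, torch.tensor([1.0, 0.0]))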