From 2d2a37d30953af6a2053b3e26b4f2641c6e9836b Mon Sep 17 00:00:00 2001 From: aBotel <admbotel@uwaterloo.ca> Date: Sat, 15 Apr 2023 02:24:54 -0700 Subject: [PATCH] Add some additional text LUT code and DataLoader skeleton. --- src/cs598_dlh_final_project.ipynb | 410 ++++++++++++++++++++++++++++-- 1 file changed, 383 insertions(+), 27 deletions(-) diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb index 05528cc..e3a3d55 100644 --- a/src/cs598_dlh_final_project.ipynb +++ b/src/cs598_dlh_final_project.ipynb @@ -230,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 119, "id": "b881e548-4d27-4725-8c47-b2612157929e", "metadata": {}, "outputs": [], @@ -250,6 +250,7 @@ " def __init__(self, *args, **kwargs):\n", " self._valid_text_tables = [\"D_ICD_DIAGNOSES\", \"D_ITEMS\", \"D_ICD_PROCEDURES\", \"D_LABITEMS\"]\n", " self._text_descriptions = {x: {} for x in self._valid_text_tables}\n", + " self._text_luts = {x: {} for x in self._valid_text_tables}\n", " super().__init__(*args, **kwargs)\n", " \n", " def get_all_tables(self) -> List[str]: \n", @@ -258,6 +259,12 @@ " def get_text_dict(self, table_name: str) -> Dict[str, Dict[Any, Any]]:\n", " return self._text_descriptions.get(table_name)\n", " \n", + " def set_text_lut(self, table_name: str, lut: Dict[Any, Any]) -> None:\n", + " self._text_luts[table_name] = lut\n", + " \n", + " def get_text_lut(self, table_name: str) -> Dict[Any, Any]:\n", + " return self._text_luts[table_name]\n", + " \n", " # Note the name has to match the table name exactly.\n", " # See https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/mimic3.py#L71.\n", " def parse_d_icd_diagnoses(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: \n", @@ -408,7 +415,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 124, "id": "427a42a6-441e-438c-96f4-b38ba82fd192", "metadata": {}, "outputs": [ @@ -416,7 +423,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 616.21it/s]\n" + "Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 594.73it/s]\n" ] }, { @@ -433,11 +440,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "Parsing DIAGNOSES_ICD: 100%|████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 5045.37it/s]\n", - "Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 6525.72it/s]\n", - "Parsing PRESCRIPTIONS: 100%|█████████████████████████████████████████████████████████| 122/122 [00:00<00:00, 136.63it/s]\n", - "Parsing LABEVENTS: 100%|██████████████████████████████████████████████████████████████| 129/129 [00:05<00:00, 25.78it/s]\n", - "Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 244.02it/s]\n" + "Parsing DIAGNOSES_ICD: 100%|████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 5116.36it/s]\n", + "Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 6220.39it/s]\n", + "Parsing PRESCRIPTIONS: 100%|█████████████████████████████████████████████████████████| 122/122 [00:00<00:00, 137.11it/s]\n", + "Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 253.20it/s]\n" ] }, { @@ -457,7 +463,6 @@ "\t- Number of events per visit in DIAGNOSES_ICD: 13.6512\n", "\t- Number of events per visit in PROCEDURES_ICD: 3.9225\n", "\t- Number of events per visit in PRESCRIPTIONS: 115.6667\n", - "\t- Number of events per visit in LABEVENTS: 479.1628\n", "\n", "\n", "dataset.patients: patient_id -> <Patient>\n", @@ -479,15 +484,15 @@ ], "source": [ "mimic3base = MIMIC3DatasetWrapper(\n", - " #root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/\",\n", + " # root=\"https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/\",\n", " root=os.path.join(os.getcwd(), DATA_DIR_),\n", " tables=[\"D_ICD_DIAGNOSES\", \"D_ICD_PROCEDURES\", \"D_ITEMS\", \"D_LABITEMS\",\n", - " \"DIAGNOSES_ICD\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\", \"LABEVENTS\"],\n", + " \"DIAGNOSES_ICD\", \"PROCEDURES_ICD\", \"PRESCRIPTIONS\"], # \"LABEVENTS\"],\n", " # map all NDC codes to ATC 3-rd level codes in these tables\n", " # See https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.\n", " code_mapping={\"NDC\": (\"ATC\", {\"target_kwargs\": {\"level\": 3}})},\n", " # Slow\n", - " refresh_cache=True,\n", + " refresh_cache=False,\n", ")\n", "\n", "mimic3base.stat()\n", @@ -496,7 +501,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 125, "id": "4cef4476-6bc8-4791-b246-4c329a80a2e4", "metadata": {}, "outputs": [ @@ -504,7 +509,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "dict_keys(['D_ICD_DIAGNOSES', 'D_ITEMS', 'D_ICD_PROCEDURES', 'D_LABITEMS'])\n", + "['D_ICD_DIAGNOSES', 'D_ITEMS', 'D_ICD_PROCEDURES', 'D_LABITEMS']\n", "Table: D_ICD_DIAGNOSES\n", "[['0010', 'Cholera d/t vib cholerae', 'Cholera due to vibrio cholerae'], ['0011', 'Cholera d/t vib el tor', 'Cholera due to vibrio cholerae el tor'], ['0019', 'Cholera NOS', 'Cholera, unspecified'], ['0020', 'Typhoid fever', 'Typhoid fever'], ['0021', 'Paratyphoid fever a', 'Paratyphoid fever A']]\n", "\n", @@ -536,7 +541,14 @@ " d = mimic3base.get_text_dict(t)\n", " print(f\"Table: {t}\")\n", " print(d['data'][:5])\n", - " print('\\n\\n')\n" + " print('\\n\\n')\n", + "\n", + "\n", + "for t in table_names:\n", + " d = mimic3base.get_text_dict(t)\n", + " d = d['data']\n", + " lut = {record[0]: record[1:] for record in d}\n", + " mimic3base.set_text_lut(t, lut)" ] }, { @@ -551,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 115, "id": "08692aad-db67-487b-9b26-dbde0ab6adce", "metadata": {}, "outputs": [], @@ -573,6 +585,7 @@ " for i in range(len(patient) - 1):\n", "\n", " # visit and next_visit are both <pyhealth.data.Visit> objects\n", + " # there are no vists.attr_dict keys\n", " visit = patient[i]\n", " next_visit = patient[i + 1]\n", "\n", @@ -585,11 +598,20 @@ " # step 2: get code-based feature information\n", " conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n", " procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n", - " drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n", - " labevents = visit.get_code_list(table=\"LABEVENTS\")\n", + " # drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n", + " # TODO(botelho3) - add this datasource back in once we have full MIMIC-III dataset.\n", + " # labevents = visit.get_code_list(table=\"LABEVENTS\")\n", "\n", " # step 3: exclusion criteria: visits without condition, procedure, or drug\n", - " if len(conditions) * len(procedures) * len(drugs) == 0: continue\n", + " if len(conditions) * len(procedures) * len(labevents) == 0: continue\n", + " \n", + " # step 3.5: build text lists from the ICD codes\n", + " d_diag = mimic3base._text_luts.get(\"D_DIAGNOSES_ICD\").get()\n", + " d_proc = mimic3base._text_luts.get(\"D_PROCEDURES_ICD\").get()\n", + " # Index 0 is shortname, index 1 is longname.\n", + " conditions_text = [d_diag.get(cond)[1] for cond in conditions]\n", + " procedutes_text = [d_proc.get(proc)[1] for proc in procedures]\n", + " # labevents_text =\n", " \n", " # step 4: assemble the samples\n", " samples.append(\n", @@ -599,7 +621,11 @@ " # the following keys can be the \"feature_keys\" or \"label_key\" for initializing downstream ML model\n", " \"conditions\": conditions,\n", " \"procedures\": procedures,\n", - " \"drugs\": drugs,\n", + " \"conditions_text\": conditions_text,\n", + " \"procedures_text\": procedures_text,\n", + " # \"drugs\": drugs,\n", + " # \"labevents\": labevents,\n", + " # \"labevents_text\": labevents_text\n", " \"label\": mortality_label,\n", " }\n", " )\n", @@ -628,6 +654,7 @@ " conditions = visit.get_code_list(table=\"DIAGNOSES_ICD\")\n", " procedures = visit.get_code_list(table=\"PROCEDURES_ICD\")\n", " drugs = visit.get_code_list(table=\"PRESCRIPTIONS\")\n", + " labevents = visit.get_code_list(table=\"LABEVENTS\")\n", "\n", " # step 3: exclusion criteria: visits without condition, procedure, or drug\n", " if len(conditions) * len(procedures) * len(drugs) == 0: continue\n", @@ -653,13 +680,224 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 105, "id": "96d047ff-6412-462c-9294-a5ba2d3bdba2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating samples for readmission_pred_task: 44%|███████████████■| 44/100 [00:00<00:00, 386.68it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "visit num_events: 598\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 252\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 414\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 276\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 314\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 774\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 1723\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 635\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 856\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 288\n", + "visit tables: ['DIAGNOSES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 252\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 393\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 641\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 310\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 361\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 178\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 385\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating samples for readmission_pred_task: 100%|██████████████████████████████████| 100/100 [00:00<00:00, 328.61it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "visit num_events: 908\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 498\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 1178\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 822\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 185\n", + "visit tables: ['DIAGNOSES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 1496\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 177\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 1036\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 130\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 369\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 268\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "visit num_events: 136\n", + "visit tables: ['DIAGNOSES_ICD', 'PROCEDURES_ICD', 'PRESCRIPTIONS', 'LABEVENTS']\n", + "visit attrs\n", + "{}\n", + "Statistics of sample dataset:\n", + "\t- Dataset: MIMIC3DatasetWrapper\n", + "\t- Task: readmission_pred_task\n", + "\t- Number of samples: 27\n", + "\t- Number of patients: 13\n", + "\t- Number of visits: 27\n", + "\t- Number of visits per patient: 2.0769\n", + "\t- conditions:\n", + "\t\t- Number of conditions per sample: 1.0000\n", + "\t\t- Number of unique conditions: 17\n", + "\t\t- Distribution of conditions (Top-10): [('0389', 8), ('486', 3), ('03849', 2), ('5715', 1), ('41071', 1), ('5750', 1), ('51884', 1), ('20280', 1), ('53642', 1), ('51881', 1)]\n", + "\t- procedures:\n", + "\t\t- Number of procedures per sample: 3.4074\n", + "\t\t- Number of unique procedures: 42\n", + "\t\t- Distribution of procedures (Top-10): [('966', 15), ('3893', 14), ('9671', 5), ('9672', 4), ('9604', 4), ('4513', 3), ('3891', 3), ('9915', 3), ('9904', 3), ('3995', 2)]\n", + "\t- drugs:\n", + "\t\t- Number of drugs per sample: 40.7037\n", + "\t\t- Number of unique drugs: 112\n", + "\t\t- Distribution of drugs (Top-10): [('B05X', 27), ('B01A', 25), ('N02A', 25), ('A02B', 24), ('A06A', 24), ('N02B', 23), ('J01X', 21), ('A12C', 20), ('A12B', 20), ('V04C', 19)]\n", + "\t- label:\n", + "\t\t- Number of label per sample: 1.0000\n", + "\t\t- Number of unique label: 2\n", + "\t\t- Distribution of label (Top-10): [(0, 26), (1, 1)]\n" + ] + }, + { + "data": { + "text/plain": [ + "{'visit_id': '122098',\n", + " 'patient_id': '10059',\n", + " 'conditions': '5715',\n", + " 'procedures': ['4513',\n", + " '3891',\n", + " '3893',\n", + " '9672',\n", + " '9915',\n", + " '3895',\n", + " '3995',\n", + " '5491'],\n", + " 'drugs': ['A12C',\n", + " 'B05X',\n", + " 'V04C',\n", + " 'A02B',\n", + " 'J01F',\n", + " 'N05C',\n", + " 'N01A',\n", + " 'A12A',\n", + " 'H01C',\n", + " 'B02B',\n", + " 'V06D',\n", + " 'A10A',\n", + " 'J01M',\n", + " 'C01C',\n", + " 'B01A',\n", + " 'B05B',\n", + " 'C05B',\n", + " 'J01X',\n", + " 'A06A',\n", + " 'B05A',\n", + " 'H02A',\n", + " 'N02A',\n", + " 'A04A'],\n", + " 'label': 0}" + ] + }, + "execution_count": 105, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "re_dataset = dataset.set_task(readmission_pred_task)\n", - "re_dataset.stat()" + "readm_dataset = mimic3base.set_task(readmission_pred_task)\n", + "readm_dataset.stat()\n", + "readm_dataset.samples[0]" ] }, { @@ -669,8 +907,72 @@ "metadata": {}, "outputs": [], "source": [ - "mor_dataset = dataset.set_task(mortality_pred_task)\n", - "mor_dataset.stat()" + "mor_dataset = mimic3base.set_task(mortality_pred_task)\n", + "mor_dataset.stat()\n", + "mor_dataset.samples[0]" + ] + }, + { + "cell_type": "markdown", + "id": "66a0281a-f95b-4991-abd5-5824fc37c520", + "metadata": {}, + "source": [ + "### DataLoaders and Collate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "924c8e3c-0482-4688-b2d6-d83d490060a9", + "metadata": {}, + "outputs": [], + "source": [ + "# def collate_fn(data):\n", + "# \"\"\"\n", + "# Collate the the list of samples into batches. For each patient, you need to pad the diagnosis\n", + "# sequences to the sample shape (max # visits, len(freq_codes)). The padding infomation\n", + "# is stored in `mask`.\n", + " \n", + "# Arguments:\n", + "# data: a list of samples fetched from `CustomDataset`\n", + " \n", + "# Outputs:\n", + "# x: a tensor of shape (# patiens, max # visits, len(freq_codes)) of type torch.float\n", + "# masks: a tensor of shape (# patiens, max # visits) of type torch.bool\n", + "# y: a tensor of shape (# patiens) of type torch.float\n", + " \n", + "# Note that you can obtains the list of diagnosis codes and the list of hf labels\n", + "# using: `sequences, labels = zip(*data)`\n", + "# \"\"\"\n", + "\n", + "# sequences, labels = zip(*data)\n", + "\n", + "# y = torch.tensor(labels, dtype=torch.float)\n", + " \n", + "# num_patients = len(sequences)\n", + "# num_visits = [len(patient) for patient in sequences]\n", + "# max_num_visits = max(num_visits)\n", + " \n", + "# x = torch.zeros((num_patients, max_num_visits, len(freq_codes)), dtype=torch.float) \n", + "# for i_patient, patient in enumerate(sequences):\n", + "# kMaxVisits = len(patient)\n", + "# for j_visit, visit in enumerate(patient):\n", + "# l = len(visit)\n", + "# for code in visit:\n", + "# \"\"\"\n", + "# TODO: 1. check if code is in freq_codes;\n", + "# 2. obtain the code index using code2idx;\n", + "# 3. set the correspoindg element in x to 1.\n", + "# \"\"\"\n", + "# try:\n", + "# idx = code2idx[code]\n", + "# x[i_patient, j_visit, idx] = 1\n", + "# except KeyError as e:\n", + "# pass\n", + " \n", + "# masks = torch.sum(x, dim=-1) > 0\n", + " \n", + "# return x, masks, y" ] }, { @@ -683,7 +985,61 @@ "# create dataloaders (torch.data.DataLoader)\n", "train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)\n", "val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)\n", - "test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n" + "test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)\n", + "\n", + "# from torch.utils.data import DataLoader\n", + "\n", + "# def load_data(train_dataset, val_dataset, collate_fn):\n", + " \n", + "# '''\n", + "# TODO: Implement this function to return the data loader for train and validation dataset. \n", + "# Set batchsize to 32. Set `shuffle=True` only for train dataloader.\n", + " \n", + "# Arguments:\n", + "# train dataset: train dataset of type `CustomDataset`\n", + "# val dataset: validation dataset of type `CustomDataset`\n", + "# collate_fn: collate function\n", + " \n", + "# Outputs:\n", + "# train_loader, val_loader: train and validation dataloaders\n", + " \n", + "# Note that you need to pass the collate function to the data loader `collate_fn()`.\n", + "# '''\n", + " \n", + "# batch_size = 32\n", + "# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)\n", + " \n", + "\n", + "\n", + "# return train_loader, val_loader\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "035fccee-d5da-428c-a54a-0678a2d19b5e", + "metadata": {}, + "outputs": [], + "source": [ + "# Verify DataLoader properties.\n", + "\n", + "# from torch.utils.data import DataLoader\n", + "\n", + "# loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)\n", + "# loader_iter = iter(loader)\n", + "# x, masks, y = next(loader_iter)\n", + "\n", + "# assert x.dtype == torch.float\n", + "# assert y.dtype == torch.float\n", + "# assert masks.dtype == torch.bool\n", + "\n", + "# assert x.shape == (10, 3, 105)\n", + "# assert y.shape == (10,)\n", + "# assert masks.shape == (10, 3)\n", + "\n", + "# assert x[0][0].sum() == 9\n", + "# assert masks[0].sum() == 2" ] }, { -- GitLab