diff --git a/install.sh b/install.sh index fa9e65620a15678471a4f475ccecbb97a23c0e20..710e2d41cfae92f1a49c1432d708da8721b8994d 100755 --- a/install.sh +++ b/install.sh @@ -6,8 +6,9 @@ # Make sure the BERT model is available in PyTorch #conda config --add channels conda-forge #conda install --yes -c conda-forge pytorch-pretrained-bert +#conda install -c huggingface transformers -pip install pytorch-transformers +pip install transformers pip install pyhealth pip install termcolor pip install pandarallel diff --git a/src/cs598_dlh_final_project.ipynb b/src/cs598_dlh_final_project.ipynb index 7e3e84f6e4b93e9816d14ff528ac85c45be2b23e..f67d1ccd07e03659a21281530247badb8691b11e 100644 --- a/src/cs598_dlh_final_project.ipynb +++ b/src/cs598_dlh_final_project.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 21, "id": "5da368cd-3045-4ec0-86fb-2ce15b5d1b92", "metadata": {}, "outputs": [ @@ -26,6 +26,8 @@ "source": [ "# General includes.\n", "import os\n", + "import random\n", + "import math\n", "import itertools\n", "import functools\n", "#from termcolor import colored, cprint\n", @@ -54,13 +56,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "id": "95e714b1-ca5d-4d7b-9ace-6b4372adfd9f", "metadata": {}, "outputs": [], "source": [ "# Model imports \n", - "from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM" + "from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig" ] }, { @@ -73,14 +75,23 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 22, "id": "ebef3528-0396-459a-8112-b272089f5d67", "metadata": {}, "outputs": [], "source": [ "USE_GPU_ = False\n", "GPU_STRING_ = 'cuda'\n", - "DATA_DIR_ = '../data_input_path/mimic'" + "DATA_DIR_ = '../data_input_path/mimic'\n", + "BATCH_SIZE_ = 32\n", + "EMBEDDING_DIM_ = 264 # BERT requires a multiple of 12\n", + "SHUFFLE_ = True\n", + "# Set seed for reproducibility.\n", + "seed = 90210\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n", + "os.environ[\"PYTHONHASHSEED\"] = str(seed)" ] }, { @@ -109,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 4, "id": "607ef1d3-caa8-4db8-b664-86172bea936d", "metadata": {}, "outputs": [ @@ -151,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 5, "id": "dcfb3443-8dfa-4149-904b-04fc24d7a33e", "metadata": {}, "outputs": [], @@ -440,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "id": "427a42a6-441e-438c-96f4-b38ba82fd192", "metadata": { "tags": [] @@ -450,7 +461,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Parsing PATIENTS and ADMISSIONS: 100%|███████████████████████████████████████████████| 100/100 [00:00<00:00, 593.18it/s]\n" + "Parsing PATIENTS and ADMISSIONS: 100%|████████████████████████████████| 100/100 [00:00<00:00, 587.64it/s]\n" ] }, { @@ -467,14 +478,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "Parsing DIAGNOSES_ICD: 100%|████████████████████████████████████████████████████████| 129/129 [00:00<00:00, 3274.00it/s]\n", - "Parsing PROCEDURES_ICD: 100%|███████████████████████████████████████████████████████| 113/113 [00:00<00:00, 4710.08it/s]\n" + "Parsing DIAGNOSES_ICD: 100%|█████████████████████████████████████████| 129/129 [00:00<00:00, 5849.54it/s]\n", + "Parsing PROCEDURES_ICD: 100%|████████████████████████████████████████| 113/113 [00:00<00:00, 5903.28it/s]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "26f454c50c844918b91d6526df73a672", + "model_id": "ce1dd4f4d61f4a57bac6a6a072c8be69", "version_major": 2, "version_minor": 0 }, @@ -489,7 +500,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Mapping codes: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 196.96it/s]\n" + "Mapping codes: 100%|██████████████████████████████████████████████████| 100/100 [00:00<00:00, 160.72it/s]" ] }, { @@ -526,6 +537,13 @@ " - other event-level info\n", "\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] } ], "source": [ @@ -547,7 +565,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 7, "id": "0e2d58de-bf9a-4b57-977f-6aa3ac75e84f", "metadata": {}, "outputs": [], @@ -571,7 +589,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 8, "id": "4cef4476-6bc8-4791-b246-4c329a80a2e4", "metadata": {}, "outputs": [ @@ -655,7 +673,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 9, "id": "08692aad-db67-487b-9b26-dbde0ab6adce", "metadata": {}, "outputs": [], @@ -807,10 +825,213 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "96d047ff-6412-462c-9294-a5ba2d3bdba2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating samples for readmission_pred_task: 100%|█████████████████| 100/100 [00:00<00:00, 10082.46it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Statistics of sample dataset:\n", + "\t- Dataset: MIMIC3DatasetWrapper\n", + "\t- Task: readmission_pred_task\n", + "\t- Number of samples: 27\n", + "\t- Number of patients: 13\n", + "\t- Number of visits: 27\n", + "\t- Number of visits per patient: 2.0769\n", + "\t- conditions:\n", + "\t\t- Number of conditions per sample: 15.2222\n", + "\t\t- Number of unique conditions: 178\n", + "\t\t- Distribution of conditions (Top-10): [('4019', 17), ('5849', 15), ('2449', 13), ('42731', 13), ('2724', 12), ('99592', 11), ('0389', 10), ('V5861', 10), ('70703', 10), ('3572', 10)]\n", + "\t- procedures:\n", + "\t\t- Number of procedures per sample: 3.4074\n", + "\t\t- Number of unique procedures: 42\n", + "\t\t- Distribution of procedures (Top-10): [('966', 15), ('3893', 14), ('9671', 5), ('9672', 4), ('9604', 4), ('4513', 3), ('3891', 3), ('9915', 3), ('9904', 3), ('3995', 2)]\n", + "\t- conditions_text:\n", + "\t\t- Number of conditions_text per sample: 15.2222\n", + "\t\t- Number of unique conditions_text: 176\n", + "\t\t- Distribution of conditions_text (Top-10): [('Unspecified essential hypertension', 17), ('Acute kidney failure, unspecified', 15), ('Unspecified acquired hypothyroidism', 13), ('Atrial fibrillation', 13), ('Other and unspecified hyperlipidemia', 12), ('Severe sepsis', 11), ('Unspecified septicemia', 10), ('Long-term (current) use of anticoagulants', 10), ('Pressure ulcer, lower back', 10), ('Polyneuropathy in diabetes', 10)]\n", + "\t- procedures_text:\n", + "\t\t- Number of procedures_text per sample: 3.4074\n", + "\t\t- Number of unique procedures_text: 42\n", + "\t\t- Distribution of procedures_text (Top-10): [('Enteral infusion of concentrated nutritional substances', 15), ('Venous catheterization, not elsewhere classified', 14), ('Continuous invasive mechanical ventilation for less than 96 consecutive hours', 5), ('Continuous invasive mechanical ventilation for 96 consecutive hours or more', 4), ('Insertion of endotracheal tube', 4), ('Other endoscopy of small intestine', 3), ('Arterial catheterization', 3), ('Parenteral infusion of concentrated nutritional substances', 3), ('Transfusion of packed cells', 3), ('Hemodialysis', 2)]\n", + "\t- drugs:\n", + "\t\t- Number of drugs per sample: 40.7037\n", + "\t\t- Number of unique drugs: 112\n", + "\t\t- Distribution of drugs (Top-10): [('B05X', 27), ('B01A', 25), ('N02A', 25), ('A02B', 24), ('A06A', 24), ('N02B', 23), ('J01X', 21), ('A12C', 20), ('A12B', 20), ('V04C', 19)]\n", + "\t- drugs_text:\n", + "\t\t- Number of drugs_text per sample: 206.9630\n", + "\t\t- Number of unique drugs_text: 432\n", + "\t\t- Distribution of drugs_text (Top-10): [('Neutra-Phos MAIN Powder Packet PO/NG 0.0', 2128), ('Neutra-Phos MAIN Powder Packet PO 0.0', 1463), ('Magnesium Sulfate MAIN 2 g / 50 mL Premix Bag IV 0.0', 129), ('Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0', 64), ('Bisacodyl MAIN 5 mg Tab PO 0.0', 48), ('Furosemide MAIN 40mg/4mL Vial IV 0.0', 46), ('0.9% Sodium Chloride BASE 1000mL Bag IV 0.0', 43), ('Insulin MAIN 100 Units / mL - 10 mL Vial SC 0.0', 35), ('0.9% Sodium Chloride BASE 500mL Bag IV 0.0', 32), ('0.9% Sodium Chloride (Mini Bag Plus) BASE 100mL Bag IV 0.0', 32)]\n", + "\t- label:\n", + "\t\t- Number of label per sample: 1.0000\n", + "\t\t- Number of unique label: 2\n", + "\t\t- Distribution of label (Top-10): [(1, 16), (0, 11)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'visit_id': '122098',\n", + " 'patient_id': '10059',\n", + " 'conditions': ['5715',\n", + " '452',\n", + " '07070',\n", + " '2800',\n", + " '45620',\n", + " '0389',\n", + " '99592',\n", + " '78552',\n", + " '51881',\n", + " '5724',\n", + " '5849',\n", + " '5859',\n", + " '78951',\n", + " '5723'],\n", + " 'procedures': ['4513',\n", + " '3891',\n", + " '3893',\n", + " '9672',\n", + " '9915',\n", + " '3895',\n", + " '3995',\n", + " '5491'],\n", + " 'conditions_text': ['Cirrhosis of liver without mention of alcohol',\n", + " 'Portal vein thrombosis',\n", + " 'Unspecified viral hepatitis C without hepatic coma',\n", + " 'Iron deficiency anemia secondary to blood loss (chronic)',\n", + " 'Esophageal varices in diseases classified elsewhere, with bleeding',\n", + " 'Unspecified septicemia',\n", + " 'Severe sepsis',\n", + " 'Septic shock',\n", + " 'Acute respiratory failure',\n", + " 'Hepatorenal syndrome',\n", + " 'Acute kidney failure, unspecified',\n", + " 'Chronic kidney disease, unspecified',\n", + " 'Malignant ascites',\n", + " 'Portal hypertension'],\n", + " 'procedures_text': ['Other endoscopy of small intestine',\n", + " 'Arterial catheterization',\n", + " 'Venous catheterization, not elsewhere classified',\n", + " 'Continuous invasive mechanical ventilation for 96 consecutive hours or more',\n", + " 'Parenteral infusion of concentrated nutritional substances',\n", + " 'Venous catheterization for renal dialysis',\n", + " 'Hemodialysis',\n", + " 'Percutaneous abdominal drainage'],\n", + " 'drugs': ['A12C',\n", + " 'B05X',\n", + " 'V04C',\n", + " 'A02B',\n", + " 'J01F',\n", + " 'N05C',\n", + " 'N01A',\n", + " 'A12A',\n", + " 'H01C',\n", + " 'B02B',\n", + " 'V06D',\n", + " 'A10A',\n", + " 'J01M',\n", + " 'C01C',\n", + " 'B01A',\n", + " 'B05B',\n", + " 'C05B',\n", + " 'J01X',\n", + " 'A06A',\n", + " 'B05A',\n", + " 'H02A',\n", + " 'N02A',\n", + " 'A04A'],\n", + " 'drugs_text': ['Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Pantoprazole Sodium MAIN 40mg Vial IV 0.0',\n", + " 'Erythromycin MAIN 1000 mg Vial IV 0.0',\n", + " 'Midazolam HCl MAIN 2mg/2mL Vial IV 0.0',\n", + " 'Fentanyl Citrate MAIN 100mcg/2mL Amp IV 0.0',\n", + " 'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n", + " 'NS BASE 150mL Bab IV 0.0',\n", + " 'Octreotide Acetate MAIN 50mcg/mL-1mL IV 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n", + " 'Phytonadione MAIN 5mg Tab PO 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'Dextrose 50% MAIN 50mL Syringe IV 0.0',\n", + " 'Insulin MAIN 10mL Vial IV 0.0',\n", + " 'Insulin MAIN 10mL Vial IV 0.0',\n", + " 'Calcium Chloride MAIN 1g/10mL Syringe IV 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Ciprofloxacin MAIN 400mg PM Bag IV 0.0',\n", + " 'Norepinephrine MAIN 4mg/4mL Amp IV DRIP 0.0',\n", + " 'Heparin Flush CVL (100 units/ml) MAIN 100Units/mL-5mL Syringe IV 0.0',\n", + " 'D5W BASE 100mL Bag IV DRIP 0.0',\n", + " 'D5W BASE 250mL Bag IV DRIP 0.0',\n", + " 'Octreotide Acetate MAIN 0.1 mg Amp IV DRIP 0.0',\n", + " 'Fentanyl Citrate MAIN 2.5mg/50mL Vial IV DRIP 0.0',\n", + " 'D5W BASE 250mL Bag IV DRIP 0.0',\n", + " 'Midazolam HCl MAIN 50mg/10mL Vial IV DRIP 0.0',\n", + " 'D5W BASE 100mL Bag IV DRIP 0.0',\n", + " 'Heparin Flush CRRT (5000 Units/mL) MAIN 5000 Units / mL- 1mL Vial IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Amino Acids 4.25% W/ Dextrose 5% BASE 1000ml Bag IV 0.0',\n", + " 'Amino Acids 4.25% W/ Dextrose 5% BASE 1000ml Bag IV 0.0',\n", + " 'Vancomycin HCl MAIN 1g Frozen Bag IV 0.0',\n", + " 'Sodium Bicarbonate MAIN 50mEq Vial IV 0.0',\n", + " 'Lactulose MAIN 1000ml Bottle PR 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'D5W BASE 3000 mL Bag IV 0.0',\n", + " 'Sodium Bicarbonate MAIN 50mEq Vial IV 0.0',\n", + " 'D5W BASE 1000mL Bag IV 0.0',\n", + " 'Insulin Human Regular MAIN 100Units/mL; 10mL Vial IV DRIP 0.0',\n", + " 'Insulin Human Regular MAIN 100Units/mL; 10mL Vial IV DRIP 0.0',\n", + " 'NS BASE 100mL Bag IV DRIP 0.0',\n", + " 'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n", + " 'Potassium Chloride MAIN 10mEq/100mL Premix IV 0.0',\n", + " 'D5W BASE 500mL Bag IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Vancomycin HCl MAIN 1g Frozen Bag IV 0.0',\n", + " 'NS BASE 250mL Bag IV 0.0',\n", + " 'Heparin Sodium MAIN 25,000 unit Premix Bag IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Albumin 25% (12.5 g) MAIN 25%-50mL IV 0.0',\n", + " 'Hydrocortisone Na Succ. MAIN 100mg/2mL Vial IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Morphine Sulfate MAIN 100mg/100mL Premix IV DRIP 0.0',\n", + " 'Scopolamine Patch MAIN 1.5mg Patch TP 0.0',\n", + " 'Scopolamine Patch MAIN 1.5mg Patch TP 0.0'],\n", + " 'label': 1}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# set_task() returns a SampleEHRDataset object\n", "readm_dataset = mimic3base.set_task(readmission_pred_task)\n", @@ -820,10 +1041,213 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "70586b66-a814-4898-892b-d1bdd339ce7b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating samples for mortality_pred_task: 100%|███████████████████| 100/100 [00:00<00:00, 11274.10it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Statistics of sample dataset:\n", + "\t- Dataset: MIMIC3DatasetWrapper\n", + "\t- Task: mortality_pred_task\n", + "\t- Number of samples: 27\n", + "\t- Number of patients: 13\n", + "\t- Number of visits: 27\n", + "\t- Number of visits per patient: 2.0769\n", + "\t- conditions:\n", + "\t\t- Number of conditions per sample: 15.2222\n", + "\t\t- Number of unique conditions: 178\n", + "\t\t- Distribution of conditions (Top-10): [('4019', 17), ('5849', 15), ('2449', 13), ('42731', 13), ('2724', 12), ('99592', 11), ('0389', 10), ('V5861', 10), ('70703', 10), ('3572', 10)]\n", + "\t- procedures:\n", + "\t\t- Number of procedures per sample: 3.4074\n", + "\t\t- Number of unique procedures: 42\n", + "\t\t- Distribution of procedures (Top-10): [('966', 15), ('3893', 14), ('9671', 5), ('9672', 4), ('9604', 4), ('4513', 3), ('3891', 3), ('9915', 3), ('9904', 3), ('3995', 2)]\n", + "\t- conditions_text:\n", + "\t\t- Number of conditions_text per sample: 15.2222\n", + "\t\t- Number of unique conditions_text: 176\n", + "\t\t- Distribution of conditions_text (Top-10): [('Unspecified essential hypertension', 17), ('Acute kidney failure, unspecified', 15), ('Unspecified acquired hypothyroidism', 13), ('Atrial fibrillation', 13), ('Other and unspecified hyperlipidemia', 12), ('Severe sepsis', 11), ('Unspecified septicemia', 10), ('Long-term (current) use of anticoagulants', 10), ('Pressure ulcer, lower back', 10), ('Polyneuropathy in diabetes', 10)]\n", + "\t- procedures_text:\n", + "\t\t- Number of procedures_text per sample: 3.4074\n", + "\t\t- Number of unique procedures_text: 42\n", + "\t\t- Distribution of procedures_text (Top-10): [('Enteral infusion of concentrated nutritional substances', 15), ('Venous catheterization, not elsewhere classified', 14), ('Continuous invasive mechanical ventilation for less than 96 consecutive hours', 5), ('Continuous invasive mechanical ventilation for 96 consecutive hours or more', 4), ('Insertion of endotracheal tube', 4), ('Other endoscopy of small intestine', 3), ('Arterial catheterization', 3), ('Parenteral infusion of concentrated nutritional substances', 3), ('Transfusion of packed cells', 3), ('Hemodialysis', 2)]\n", + "\t- drugs:\n", + "\t\t- Number of drugs per sample: 40.7037\n", + "\t\t- Number of unique drugs: 112\n", + "\t\t- Distribution of drugs (Top-10): [('B05X', 27), ('B01A', 25), ('N02A', 25), ('A02B', 24), ('A06A', 24), ('N02B', 23), ('J01X', 21), ('A12C', 20), ('A12B', 20), ('V04C', 19)]\n", + "\t- drugs_text:\n", + "\t\t- Number of drugs_text per sample: 206.9630\n", + "\t\t- Number of unique drugs_text: 432\n", + "\t\t- Distribution of drugs_text (Top-10): [('Neutra-Phos MAIN Powder Packet PO/NG 0.0', 2128), ('Neutra-Phos MAIN Powder Packet PO 0.0', 1463), ('Magnesium Sulfate MAIN 2 g / 50 mL Premix Bag IV 0.0', 129), ('Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0', 64), ('Bisacodyl MAIN 5 mg Tab PO 0.0', 48), ('Furosemide MAIN 40mg/4mL Vial IV 0.0', 46), ('0.9% Sodium Chloride BASE 1000mL Bag IV 0.0', 43), ('Insulin MAIN 100 Units / mL - 10 mL Vial SC 0.0', 35), ('0.9% Sodium Chloride BASE 500mL Bag IV 0.0', 32), ('0.9% Sodium Chloride (Mini Bag Plus) BASE 100mL Bag IV 0.0', 32)]\n", + "\t- label:\n", + "\t\t- Number of label per sample: 1.0000\n", + "\t\t- Number of unique label: 2\n", + "\t\t- Distribution of label (Top-10): [(0, 26), (1, 1)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'visit_id': '122098',\n", + " 'patient_id': '10059',\n", + " 'conditions': ['5715',\n", + " '452',\n", + " '07070',\n", + " '2800',\n", + " '45620',\n", + " '0389',\n", + " '99592',\n", + " '78552',\n", + " '51881',\n", + " '5724',\n", + " '5849',\n", + " '5859',\n", + " '78951',\n", + " '5723'],\n", + " 'procedures': ['4513',\n", + " '3891',\n", + " '3893',\n", + " '9672',\n", + " '9915',\n", + " '3895',\n", + " '3995',\n", + " '5491'],\n", + " 'conditions_text': ['Cirrhosis of liver without mention of alcohol',\n", + " 'Portal vein thrombosis',\n", + " 'Unspecified viral hepatitis C without hepatic coma',\n", + " 'Iron deficiency anemia secondary to blood loss (chronic)',\n", + " 'Esophageal varices in diseases classified elsewhere, with bleeding',\n", + " 'Unspecified septicemia',\n", + " 'Severe sepsis',\n", + " 'Septic shock',\n", + " 'Acute respiratory failure',\n", + " 'Hepatorenal syndrome',\n", + " 'Acute kidney failure, unspecified',\n", + " 'Chronic kidney disease, unspecified',\n", + " 'Malignant ascites',\n", + " 'Portal hypertension'],\n", + " 'procedures_text': ['Other endoscopy of small intestine',\n", + " 'Arterial catheterization',\n", + " 'Venous catheterization, not elsewhere classified',\n", + " 'Continuous invasive mechanical ventilation for 96 consecutive hours or more',\n", + " 'Parenteral infusion of concentrated nutritional substances',\n", + " 'Venous catheterization for renal dialysis',\n", + " 'Hemodialysis',\n", + " 'Percutaneous abdominal drainage'],\n", + " 'drugs': ['A12C',\n", + " 'B05X',\n", + " 'V04C',\n", + " 'A02B',\n", + " 'J01F',\n", + " 'N05C',\n", + " 'N01A',\n", + " 'A12A',\n", + " 'H01C',\n", + " 'B02B',\n", + " 'V06D',\n", + " 'A10A',\n", + " 'J01M',\n", + " 'C01C',\n", + " 'B01A',\n", + " 'B05B',\n", + " 'C05B',\n", + " 'J01X',\n", + " 'A06A',\n", + " 'B05A',\n", + " 'H02A',\n", + " 'N02A',\n", + " 'A04A'],\n", + " 'drugs_text': ['Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Pantoprazole Sodium MAIN 40mg Vial IV 0.0',\n", + " 'Erythromycin MAIN 1000 mg Vial IV 0.0',\n", + " 'Midazolam HCl MAIN 2mg/2mL Vial IV 0.0',\n", + " 'Fentanyl Citrate MAIN 100mcg/2mL Amp IV 0.0',\n", + " 'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n", + " 'NS BASE 150mL Bab IV 0.0',\n", + " 'Octreotide Acetate MAIN 50mcg/mL-1mL IV 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n", + " 'Phytonadione MAIN 5mg Tab PO 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'Dextrose 50% MAIN 50mL Syringe IV 0.0',\n", + " 'Insulin MAIN 10mL Vial IV 0.0',\n", + " 'Insulin MAIN 10mL Vial IV 0.0',\n", + " 'Calcium Chloride MAIN 1g/10mL Syringe IV 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Ciprofloxacin MAIN 400mg PM Bag IV 0.0',\n", + " 'Norepinephrine MAIN 4mg/4mL Amp IV DRIP 0.0',\n", + " 'Heparin Flush CVL (100 units/ml) MAIN 100Units/mL-5mL Syringe IV 0.0',\n", + " 'D5W BASE 100mL Bag IV DRIP 0.0',\n", + " 'D5W BASE 250mL Bag IV DRIP 0.0',\n", + " 'Octreotide Acetate MAIN 0.1 mg Amp IV DRIP 0.0',\n", + " 'Fentanyl Citrate MAIN 2.5mg/50mL Vial IV DRIP 0.0',\n", + " 'D5W BASE 250mL Bag IV DRIP 0.0',\n", + " 'Midazolam HCl MAIN 50mg/10mL Vial IV DRIP 0.0',\n", + " 'D5W BASE 100mL Bag IV DRIP 0.0',\n", + " 'Heparin Flush CRRT (5000 Units/mL) MAIN 5000 Units / mL- 1mL Vial IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Amino Acids 4.25% W/ Dextrose 5% BASE 1000ml Bag IV 0.0',\n", + " 'Amino Acids 4.25% W/ Dextrose 5% BASE 1000ml Bag IV 0.0',\n", + " 'Vancomycin HCl MAIN 1g Frozen Bag IV 0.0',\n", + " 'Sodium Bicarbonate MAIN 50mEq Vial IV 0.0',\n", + " 'Lactulose MAIN 1000ml Bottle PR 0.0',\n", + " 'Sodium Polystyrene Sulfonate MAIN 15g/60mL Bottle PR 0.0',\n", + " 'D5W BASE 3000 mL Bag IV 0.0',\n", + " 'Sodium Bicarbonate MAIN 50mEq Vial IV 0.0',\n", + " 'D5W BASE 1000mL Bag IV 0.0',\n", + " 'Insulin Human Regular MAIN 100Units/mL; 10mL Vial IV DRIP 0.0',\n", + " 'Insulin Human Regular MAIN 100Units/mL; 10mL Vial IV DRIP 0.0',\n", + " 'NS BASE 100mL Bag IV DRIP 0.0',\n", + " 'Calcium Gluconate MAIN 1g/10mL Vial IV 0.0',\n", + " 'Potassium Chloride MAIN 10mEq/100mL Premix IV 0.0',\n", + " 'D5W BASE 500mL Bag IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Magnesium Sulfate MAIN 1g/2mL Vial IV 0.0',\n", + " 'Vancomycin HCl MAIN 1g Frozen Bag IV 0.0',\n", + " 'NS BASE 250mL Bag IV 0.0',\n", + " 'Heparin Sodium MAIN 25,000 unit Premix Bag IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Potassium Phosphate MAIN 3mM/mL-15mL IV 0.0',\n", + " 'Albumin 25% (12.5 g) MAIN 25%-50mL IV 0.0',\n", + " 'Hydrocortisone Na Succ. MAIN 100mg/2mL Vial IV 0.0',\n", + " 'NS BASE 500mL Bag IV 0.0',\n", + " 'Morphine Sulfate MAIN 100mg/100mL Premix IV DRIP 0.0',\n", + " 'Scopolamine Patch MAIN 1.5mg Patch TP 0.0',\n", + " 'Scopolamine Patch MAIN 1.5mg Patch TP 0.0'],\n", + " 'label': 0}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "mor_dataset = mimic3base.set_task(mortality_pred_task)\n", "mor_dataset.stat()\n", @@ -832,7 +1256,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "e67dcd2f-cb6f-4b21-b692-4526d6aff390", "metadata": { "jupyter": { @@ -990,7 +1414,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 13, "id": "924c8e3c-0482-4688-b2d6-d83d490060a9", "metadata": { "tags": [] @@ -1101,52 +1525,58 @@ "\n", "def per_visit_collate_fn(data):\n", " \"\"\"\n", - " Collate the the list of samples into batches. For each patient, you need to pad the diagnosis\n", - " sequences to the sample shape (max # visits, len(freq_codes)). The padding infomation\n", + " TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis\n", + " sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation\n", " is stored in `mask`.\n", " \n", " Arguments:\n", " data: a list of samples fetched from `CustomDataset`\n", " \n", " Outputs:\n", - " x: a tensor of shape (# patiens, max # visits, len(freq_codes)) of type torch.float\n", - " masks: a tensor of shape (# patiens, max # visits) of type torch.bool\n", + " x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long\n", + " masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool\n", + " rev_x: same as x but in reversed time. This will be used in our RNN model for masking \n", + " rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking\n", " y: a tensor of shape (# patiens) of type torch.float\n", " \n", " Note that you can obtains the list of diagnosis codes and the list of hf labels\n", " using: `sequences, labels = zip(*data)`\n", " \"\"\"\n", "\n", - " sequences, labels = zip(*data)\n", + " # sequences, labels = zip(*data)\n", + " samples = data\n", "\n", " y = torch.tensor(labels, dtype=torch.float)\n", " \n", " num_patients = len(sequences)\n", " num_visits = [len(patient) for patient in sequences]\n", + " num_codes = [len(visit) for patient in sequences for visit in patient]\n", + "\n", " max_num_visits = max(num_visits)\n", + " max_num_codes = max(num_codes)\n", " \n", - " x = torch.zeros((num_patients, max_num_visits, len(freq_codes)), dtype=torch.float) \n", + " x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n", + " rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)\n", + " masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n", + " rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)\n", " for i_patient, patient in enumerate(sequences):\n", " kMaxVisits = len(patient)\n", " for j_visit, visit in enumerate(patient):\n", + " \"\"\"\n", + " TODO: update `x`, `rev_x`, `masks`, and `rev_masks`\n", + " \"\"\"\n", " l = len(visit)\n", - " for code in visit:\n", - " \"\"\"\n", - " TODO: 1. check if code is in freq_codes;\n", - " 2. obtain the code index using code2idx;\n", - " 3. set the correspoindg element in x to 1.\n", - " \"\"\"\n", - " try:\n", - " idx = code2idx[code]\n", - " x[i_patient, j_visit, idx] = 1\n", - " except KeyError as e:\n", - " pass\n", + " x[i_patient, j_visit, 0:l] = torch.LongTensor(visit)\n", + " rev_x[i_patient, kMaxVisits-1-j_visit, 0:l] = torch.LongTensor(visit)\n", + " masks[i_patient, j_visit, 0:l] = 1\n", + " rev_masks[i_patient, kMaxVisits-1-j_visit, 0:l] = 1\n", + " # print(f\"------ v: {kMaxVisits} ------\")\n", + " # print(x[i_patient, :, ])\n", + " # print(rev_x[i_patient, :, ])\n", + " # print(masks[i_patient, :, ])\n", + " # print(rev_masks[i_patient, :, ])\n", " \n", - " masks = torch.sum(x, dim=-1) > 0\n", - " \n", - " return x, masks, y\n", - "\n", - "\n", + " return x, masks, rev_x, rev_masks, y\n", "\n", "def event_seq_collate_fn(data):\n", " # TBD(botelho3)\n", @@ -1155,7 +1585,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 26, "id": "f6afd1f5-0f3b-4f4b-9c85-f74bb367ef88", "metadata": {}, "outputs": [], @@ -1171,49 +1601,100 @@ " matched to output_size. If int, smaller of image edges is matched\n", " to output_size keeping aspect ratio the same.\n", " \"\"\"\n", + " # https://stackoverflow.com/questions/69517460/bert-get-sentence-embedding\n", + " # https://huggingface.co/docs/transformers/internal/tokenization_utils#transformers.tokenization_utils_base.PreTrainedTokenizerBase\n", + " # https://gmihaila.github.io/tutorial_notebooks/bert_inner_workings/\n", "\n", " def __init__(self, bert_model: Any, embedding_size: int):\n", " assert isinstance(embedding_size, (int, tuple))\n", - " self.embedding_size = embedding_size \n", - " self.bert_model = bert_model\n", - "\n", + " self.bert_config = BertConfig(hidden_size=EMBEDDING_DIM_)\n", + " # self.bert_model = BertModel(self.bert_config).from_pretrained('bert-base-uncased', config=self.bert_config)\n", + " self.bert_model = BertModel(self.bert_config).from_pretrained('bert-base-uncased')\n", + " self.bert_config = self.bert_model.config\n", + " self.bert_model.eval()\n", + " self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', model_max_length=100)\n", + " \n", + " def _tokenize_text(self, text: str) -> str:\n", + " tokenized_text = self.tokenizer.tokenize(text)\n", + " return tokenized_text\n", + " \n", + " def _get_embeddings_of_sentences(self, sentences: List[str]) -> torch.tensor:\n", + " # tokenized_sentences = [self.tokenizer.tokenize(t, padding=True) for t in sentences]\n", + " # print(batch_enc)\n", + " batch_enc = self.tokenizer.batch_encode_plus(\n", + " sentences, padding=True,\n", + " return_attention_mask=True, return_length=True)\n", + " # Convert inputs to PyTorch tensors\n", + " # tokens_tensor = torch.tensor([tokenized_sentences])\n", + " # segments_tensors = torch.tensor([segments_ids])\n", + " # Convert inputs to PyTorch tensors\n", + " batch_enc_tensor = {key:torch.LongTensor(value) for key, value in batch_enc.items()}\n", + " \n", + " with torch.no_grad():\n", + " embeddings = self.bert_model(input_ids=batch_enc_tensor['input_ids'],\n", + " attention_mask=batch_enc_tensor['attention_mask'])\n", + " # embeddings, _ = self.bert_model(**batch_enc)\n", + " # print(f'embeddings:\\n {dir(embeddings)}')\n", + " #attention = encoded['attention_mask'].reshape((lhs.size()[0], lhs.size()[1], -1)).expand(-1, -1, 768)\n", + " return embeddings.last_hidden_state, embeddings.attentions \n", + " \n", " def __call__(self, sample):\n", " condition_embeddings = [None] * len(sample['conditions_text'])\n", " procedure_embeddings = [None] * len(sample['procedures_text'])\n", " drug_embeddings = [None] * len(sample['drugs_text'])\n", - " for cond in sample['conditions_text']:\n", - " #bert_model(cond)\n", - " condition_embeddings.append(cond)\n", - " pass\n", - " for proc in sample['procedures_text']:\n", - " #bert_model(proc)\n", - " procedure_embeddings.append(proc)\n", - " pass\n", - " for drug in sample['drugs_text']:\n", - " #bert_model(drug)\n", - " procedure_embeddings.append(drug)\n", - " pass\n", + " \n", + " condition_embeddings, condition_attentions = self._get_embeddings_of_sentences(sample['conditions_text'])\n", + " # for cond in sample['conditions_text']:\n", + " # #bert_model(cond)\n", + " # condition_embeddings.append(cond)\n", + " # pass\n", + " procedure_embeddings, procedure_attentions = self._get_embeddings_of_sentences(sample['procedures_text'])\n", + " # for proc in sample['procedures_text']:\n", + " # #bert_model(proc)\n", + " # procedure_embeddings.append(proc)\n", + " # pass\n", + " drug_embeddings, drug_attentions = self._get_embeddings_of_sentences(sample['drugs_text'])\n", + " # for drug in sample['drugs_text']:\n", + " # #bert_model(drug)\n", + " # procedure_embeddings.append(drug)\n", + " # pass\n", + " \n", + " condition_embeddings = torch.squeeze(condition_embeddings[:, -1, :], dim=1)\n", + " procedure_embeddings = torch.squeeze(procedure_embeddings[:, -1, :], dim=1)\n", + " drug_embeddings = torch.squeeze(drug_embeddings[:, -1, :], dim=1)\n", + " print(f'ce: {condition_embeddings.shape}'\n", + " f'pe: {procedure_embeddings.shape}'\n", + " f'de: {drug_embeddings.shape}')\n", + " \n", + " stacked = torch.cat([condition_embeddings, procedure_embeddings, drug_embeddings], dim=0)\n", + " summed_embeddings = torch.sum(stacked, dim=0, keepdim=True) # sum across rows\n", + " # normalize the final row (across columns)\n", + " summed_embeddings = torch.nn.functional.normalize(summed_embeddings, dim=1)\n", + " \n", + " print(summed_embeddings.shape)\n", + " assert(summed_embeddings.shape == (1, self.bert_config.hidden_size))\n", + " return summed_embeddings\n", " \n", - " image, landmarks = sample['image'], sample['landmarks']\n", - "\n", - " h, w = image.shape[:2]\n", - " if isinstance(self.output_size, int):\n", - " if h > w:\n", - " new_h, new_w = self.output_size * h / w, self.output_size\n", - " else:\n", - " new_h, new_w = self.output_size, self.output_size * w / h\n", - " else:\n", - " new_h, new_w = self.output_size\n", + "# image, landmarks = sample['image'], sample['landmarks']\n", + "\n", + "# h, w = image.shape[:2]\n", + "# if isinstance(self.output_size, int):\n", + "# if h > w:\n", + "# new_h, new_w = self.output_size * h / w, self.output_size\n", + "# else:\n", + "# new_h, new_w = self.output_size, self.output_size * w / h\n", + "# else:\n", + "# new_h, new_w = self.output_size\n", "\n", - " new_h, new_w = int(new_h), int(new_w)\n", + "# new_h, new_w = int(new_h), int(new_w)\n", "\n", - " img = transform.resize(image, (new_h, new_w))\n", + "# img = transform.resize(image, (new_h, new_w))\n", "\n", - " # h and w are swapped for landmarks because for images,\n", - " # x and y axes are axis 1 and 0 respectively\n", - " landmarks = landmarks * [new_w / w, new_h / h]\n", + "# # h and w are swapped for landmarks because for images,\n", + "# # x and y axes are axis 1 and 0 respectively\n", + "# landmarks = landmarks * [new_w / w, new_h / h]\n", "\n", - " return {'image': img, 'landmarks': landmarks}\n", + "# return {'image': img, 'landmarks': landmarks}\n", " \n", " \n", "class TextEmbedDataset(SampleDataset):\n", @@ -1246,7 +1727,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 27, "id": "1056b60a-5861-4e47-b694-891053cc8470", "metadata": {}, "outputs": [ @@ -1254,26 +1735,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "Generating samples for mortality_pred_task: 100%|██████████████████████████████████| 100/100 [00:00<00:00, 16428.92it/s]\n" - ] - }, - { - "ename": "NameError", - "evalue": "name 'per_visit_collate_fn' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/var/folders/bw/pyw_1xcj0f302h0krt1_f5lm0000gn/T/ipykernel_55991/2642145762.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mBATCH_SIZE_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mSHUFFLE_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mcollate_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mper_visit_collate_fn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 23\u001b[0m mort_val_loader = DataLoader(\n", - "\u001b[0;31mNameError\u001b[0m: name 'per_visit_collate_fn' is not defined" + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Generating samples for mortality_pred_task: 100%|███████████████████| 100/100 [00:00<00:00, 11527.25it/s]\n" ] } ], "source": [ - "BATCH_SIZE_ = 32\n", - "EMBEDDING_DIM_ = 128\n", - "SHUFFLE_ = True\n", - "\n", "# Create the transform that will take each sample (visit) in the dataset\n", "# and convert the text description of the visit into a single embedding.\n", "bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_)\n", @@ -1347,20 +1816,20 @@ "\n", "# from torch.utils.data import DataLoader\n", "\n", - "# loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)\n", - "# loader_iter = iter(loader)\n", - "# x, masks, y = next(loader_iter)\n", + "# loader = DataLoader(mort_, batch_size=10, collate_fn=collate_fn)\n", + "loader_iter = iter(mort_train_loader)\n", + "x, masks, y = next(loader_iter)\n", "\n", - "# assert x.dtype == torch.float\n", - "# assert y.dtype == torch.float\n", - "# assert masks.dtype == torch.bool\n", + "assert x.dtype == torch.float\n", + "assert y.dtype == torch.float\n", + "assert masks.dtype == torch.bool\n", "\n", - "# assert x.shape == (10, 3, 105)\n", - "# assert y.shape == (10,)\n", - "# assert masks.shape == (10, 3)\n", + "assert x.shape == (10, 3, 105)\n", + "assert y.shape == (10,)\n", + "assert masks.shape == (10, 3)\n", "\n", - "# assert x[0][0].sum() == 9\n", - "# assert masks[0].sum() == 2" + "assert x[0][0].sum() == 9\n", + "assert masks[0].sum() == 2" ] }, {