Extreme Multi-Label Text Classification



EXplainable EXtreme Multi-Label TeXt Classification:

  • What is XMTC? Extreme Multi-Label Text Classification (XMTC) is the problem of automatically assigning to each data point the most relevant subset of labels from an extremely large label set. One major application of XMTC is in the global healthcare system, specifically in International Classification of Diseases (ICD) coding. ICD coding is the process of assigning codes that represent the diagnoses and procedures of a patient visit, based on the clinical notes documented by health professionals.

  • Datasets? Examples of ICD coding datasets: MIMIC-III and MIMIC-IV. Please note that you need to be a credentialed user and complete the required training to access the data.

  • What is xcube? xcube trains and explains XMTC models using LLM fine-tuning.
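To make the multi-label setting concrete, here is a minimal sketch of how a semicolon-separated string of ICD codes (the format of the labels column shown later in this README) can be turned into a multi-hot target vector. The helper names build_vocab and multi_hot are hypothetical and are not part of xcube; a real XMTC label set is far larger than this toy example.

```python
# Hypothetical helpers for illustration only; not part of xcube.

def build_vocab(label_strings, sep=";"):
    """Collect the sorted list of distinct codes seen across all label strings."""
    codes = set()
    for s in label_strings:
        codes.update(c for c in s.split(sep) if c)
    return sorted(codes)


def multi_hot(label_string, vocab, sep=";"):
    """Return a 0/1 list with a 1 at the position of every code in label_string."""
    present = set(label_string.split(sep))
    return [1 if code in present else 0 for code in vocab]


# Two toy "documents", each tagged with several ICD-9 codes.
notes = ["414.01;427.31;401.9", "250.00;403.90;585.9;401.9"]
vocab = build_vocab(notes)
vectors = [multi_hot(s, vocab) for s in notes]

print(vocab)
print(vectors)
```

Each document gets one vector whose length equals the size of the label set, with several positions switched on at once. This is what separates multi-label from ordinary multi-class classification.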


  • Create a new conda environment:
conda create -n xxx python=3.10
conda activate xxx
  • [Optional] Install PyTorch with CUDA enabled:
conda search pytorch

Use the build string that matches your Python and CUDA versions, replacing the PyTorch version and build string appropriately:

conda install pytorch=2.0.0=cuda118py310h072bc4c pytorch-cuda=11.8 -c pytorch -c nvidia

Update cuda-toolkit:

sudo apt install nvidia-cuda-toolkit

Verify that CUDA is available: run python and execute import torch; torch.cuda.is_available()

  • Install using:
pip install xcube

Configure accelerate by:

accelerate config

How to use

You can either clone the repo and open it on your own machine or, if you don’t want to set up a Python development environment, open it in Google Colab, which is an easier and quicker approach. You can open this readme page in Colab using this link.

IN_COLAB = is_colab()
Not running in Google Colab
source_mimic3 = untar_xxx(XURLs.MIMIC3_DEMO)
source_mimic4 = untar_xxx(XURLs.MIMIC4)
path = Path.cwd().parent/f"{'xcube' if IN_COLAB else ''}" # root of the repo
(path/'tmp/models').mkdir(exist_ok=True, parents=True)
tmp = path/'tmp'
# os.chdir( f"{path/'scripts'}") # To launch our train/infer scripts

Check your GPU memory! If you are running this on Google Colab, be sure to turn on the GPU runtime. You should be able to train and infer all the models with at least 16 GB of memory. However, note that training on the full versions of the datasets from scratch requires at least 48 GB of memory.

GPU: Quadro RTX 8000
You are using 0.0 GB
Total GPU memory = 44.99969482421875 GB
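The cell that produced the report above is not shown. The following is a minimal sketch of how a similar report could be generated with PyTorch, not the helper used in the original notebook. The function name gpu_report is hypothetical, and it assumes device index 0 is the GPU of interest.

```python
def gpu_report(device_index=0):
    """Describe one GPU in a few lines, or explain why no GPU can be reported."""
    try:
        import torch  # imported lazily so the sketch also runs without PyTorch
    except ImportError:
        return "PyTorch is not installed"
    if not torch.cuda.is_available():
        return "CUDA is not available"

    gib = 1024 ** 3  # bytes per GiB
    props = torch.cuda.get_device_properties(device_index)
    used = torch.cuda.memory_allocated(device_index) / gib
    total = props.total_memory / gib
    return (
        f"GPU: {props.name}\n"
        f"You are using {used:.1f} GB\n"
        f"Total GPU memory = {total} GB"
    )


print(gpu_report())
```

Note that torch.cuda.memory_allocated only counts memory held by tensors in the current process, so other jobs on the same GPU are not reflected.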

Train and Infer on MIMIC3-rare50

MIMIC3-rare50 refers to a split of MIMIC-III that contains the 50 rarest codes (refer to Knowledge Injected Prompt Based Fine-Tuning for Multi-label Few-shot ICD Coding for split creation).
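As a rough illustration of what "rarest" and "most frequent" codes mean, the sketch below counts how often each code occurs in a list of semicolon-separated label strings and selects the k least or most common ones. It is only a toy example with hypothetical helper names. The actual rare50 and top50 splits follow the procedures in the cited papers, which may include filtering steps not shown here.

```python
from collections import Counter


def code_frequencies(label_strings, sep=";"):
    """Count how many documents mention each code."""
    counts = Counter()
    for s in label_strings:
        counts.update(c for c in s.split(sep) if c)
    return counts


def rarest_codes(counts, k):
    """The k codes with the lowest counts (ties broken alphabetically)."""
    ranked = sorted(counts.items(), key=lambda kv: (kv[1], kv[0]))
    return [code for code, _ in ranked[:k]]


def most_frequent_codes(counts, k):
    """The k codes with the highest counts (ties broken alphabetically)."""
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    return [code for code, _ in ranked[:k]]


# Toy label column: three documents.
labels = ["401.9;427.31", "401.9;96.71", "401.9;427.31;318.2"]
counts = code_frequencies(labels)

print(rarest_codes(counts, 2))
print(most_frequent_codes(counts, 2))
```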

data = join_path_file('mimic3-9k_rare50', source_mimic3, ext='.csv')
!head -n 1 {data}
df = pd.read_csv(data,
                 names=['subject_id', 'hadm_id', 'text', 'labels', 'length', 'is_valid', 'split'],
                 dtype={'subject_id': str, 'hadm_id': str, 'text': str, 'labels': str, 'length': np.int64, 'is_valid': bool, 'split': str})
subject_id hadm_id text labels length is_valid split
0 2707 100626 admission date discharge date date of birth sex f service nsu history of present illness the patient is a year old patient with down syndrome who was transferred to hospital3 hospital for an expanding left subdural hematoma with change in mental status and aspiration pneumonia allergies the patient has no known allergies physical exam temp bp heart rate respiratory rate sats percent on room air the patient was awake noncommunicative at baseline attends examiner noncooperative pupils down to mm and briskly reactive eoms full face symmetric follows commands in the upper extremity moves the l... 318.2 334 False train
1 16650 176541 admission date discharge date date of birth sex m service surgery allergies mirtazapine attending first name3 lf chief complaint multiple self inflicted stab wounds major surgical or invasive procedure closure of stab wounds history of present illness patient was found in a park non verbal at the scene after self inflicted stab wounds to l chest x past medical history depression si sa x2 dm2 htn social history depression quit lost job years ago after a divorce lost health insurance afterwards multipl suicide attempts family history non contributory physical exam heent wnl cv rrr no mrg che... 34.71 424 False train

To launch the training of an XMTC model on MIMIC3-rare50:

!./ --script_list_file script_list_mimic3_rare50train

Train and Infer on MIMIC3-top50

MIMIC3-top50 refers to a split of MIMIC-III that contains the 50 most frequent codes (refer to Explainable Prediction of Medical Codes from Clinical Text for split creation).

data = join_path_file('mimic3-9k_top50', source_mimic3, ext='.csv')
!head -n 1 {data}
df = pd.read_csv(data,
                 names=['subject_id', 'hadm_id', 'text', 'labels', 'length', 'is_valid', 'split'],
                 dtype={'subject_id': str, 'hadm_id': str, 'text': str, 'labels': str, 'length': np.int64, 'is_valid': bool, 'split': str})
subject_id hadm_id text labels length is_valid split
0 86006 111912 admission date discharge date date of birth sex f service surgery allergies patient recorded as having no known allergies to drugs attending first name3 lf chief complaint 60f on coumadin was found slightly drowsy tonight then fell down stairs paramedic found her unconscious and she was intubated w o any medication head ct shows multiple iph transferred to hospital1 for further eval major surgical or invasive procedure none past medical history her medical history is significant for hypertension osteoarthritis involving bilateral knee joints with a dependence on cane for ambulation chronic... 414.01;427.31;V58.61;401.9;96.71 230 False dev
1 85950 189769 admission date discharge date service neurosurgery allergies sulfa sulfonamides attending first name3 lf chief complaint cc cc contact info major surgical or invasive procedure none history of present illness hpi 88m who lives with family had fall yesterday today had decline in mental status ems called pt was unresponsive on arrival went to osh head ct showed large r sdh pt was intubated at osh and transferred to hospital1 for further care past medical history cad s p mi in s p cabg in ventricular aneurysm at that time cath in with occluded rca unable to intervene chf reported ef 1st degre... 250.00;403.90;V45.81;96.71;585.9 304 False dev

To run inference with one of our pretrained XMTC models on MIMIC3-top50 (inference metrics: Precision@3,5,8,15):

model_fnames = L(source_mimic3.glob("**/*top50*.pth")).map(str)
fname = Path(shutil.copy(model_fnames[2], tmp/'models')).name.split('.')[0]
print(f"We are going to infer model {fname}.")
We are going to infer model mimic3_clas_top50.
!./launches/launch_top50_mimic3 --fname {fname} --no_running_decoder --infer 1
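For readers unfamiliar with the Precision@k metrics mentioned above, here is a minimal single-document sketch: take the k labels with the highest predicted scores and report the fraction that are truly relevant. The function name precision_at_k_sketch is hypothetical. xcube's own precision_at_k may be implemented differently, for example operating on batched tensors and averaging over documents.

```python
def precision_at_k_sketch(scores, relevant, k):
    """Fraction of the k highest-scoring labels that are in the gold label set.

    scores:   dict mapping label -> predicted score for one document
    relevant: set of gold labels for that document
    """
    top_k = sorted(scores, key=scores.get, reverse=True)[:k]
    hits = sum(1 for label in top_k if label in relevant)
    return hits / k


# Toy scores for one document over a five-label vocabulary.
scores = {"401.9": 0.9, "427.31": 0.8, "250.00": 0.6, "96.71": 0.3, "585.9": 0.1}
gold = {"401.9", "96.71", "585.9"}

p3 = precision_at_k_sketch(scores, gold, k=3)  # top 3: 401.9, 427.31, 250.00
p5 = precision_at_k_sketch(scores, gold, k=5)  # all five labels
print(p3, p5)
```

Here only one of the top three labels is relevant, so Precision@3 is 1/3. All three gold labels are among the top five, so Precision@5 is 3/5.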

To launch the training of an XMTC model on MIMIC3-top50 from scratch:

!./ --script_list_file script_list_mimic3_top50train

Train and Infer on MIMIC3-full

MIMIC3-full refers to the full MIMIC-III dataset. (Refer to Explainable Prediction of Medical Codes from Clinical Text for details of how the data was curated)

data = join_path_file('mimic3-9k_full', source_mimic3, ext='.csv')
!head -n 1 {data}
df = pd.read_csv(data,
                 names=['subject_id', 'hadm_id', 'text', 'labels', 'length', 'is_valid', 'split'],
                 dtype={'subject_id': str, 'hadm_id': str, 'text': str, 'labels': str, 'length': np.int64, 'is_valid': bool, 'split': str})
subject_id hadm_id text labels length is_valid split
0 86006 111912 admission date discharge date date of birth sex f service surgery allergies patient recorded as having no known allergies to drugs attending first name3 lf chief complaint 60f on coumadin was found slightly drowsy tonight then fell down stairs paramedic found her unconscious and she was intubated w o any medication head ct shows multiple iph transferred to hospital1 for further eval major surgical or invasive procedure none past medical history her medical history is significant for hypertension osteoarthritis involving bilateral knee joints with a dependence on cane for ambulation chronic... 801.35;348.4;805.06;807.01;998.30;707.24;E880.9;427.31;414.01;401.9;V58.61;V43.64;707.00;E878.1;96.71 230 False dev
1 85950 189769 admission date discharge date service neurosurgery allergies sulfa sulfonamides attending first name3 lf chief complaint cc cc contact info major surgical or invasive procedure none history of present illness hpi 88m who lives with family had fall yesterday today had decline in mental status ems called pt was unresponsive on arrival went to osh head ct showed large r sdh pt was intubated at osh and transferred to hospital1 for further care past medical history cad s p mi in s p cabg in ventricular aneurysm at that time cath in with occluded rca unable to intervene chf reported ef 1st degre... 852.25;E888.9;403.90;585.9;250.00;414.00;V45.81;96.71 304 False dev

Let’s look at the descriptions of some of the ICD-9 codes:

des = load_pickle(source_mimic3/'code_desc.pkl')
lbl_dict = dict()
for lbl in df.labels[1].split(';'):
    lbl_dict[lbl] = des.get(lbl, 'NF')
pd.DataFrame(lbl_dict.items(), columns=['icd9_code', 'description'])
icd9_code description
0 852.25 Subdural hemorrhage following injury, without mention of open intracranial wound, with prolonged [more than 24 hours] loss of consciousness, without return to pre-existing conscious level
1 E888.9 Unspecified fall
2 403.90 Hypertensive renal disease, unspecified, without mention of renal failure
3 585.9 Chronic kidney disease, unspecified
4 250.00 type II diabetes mellitus [non-insulin dependent type] [NIDDM type] [adult-onset type] or unspecified type, not stated as uncontrolled, without mention of complication
5 414.00 Coronary atherosclerosis of unspecified type of vessel, native or graft
6 V45.81 Postsurgical aortocoronary bypass status
7 96.71 Continuous mechanical ventilation for less than 96 consecutive hours

To run inference with one of our pretrained XMTC models on MIMIC3-full (inference metrics: Precision@3,5,8,15):

model_fnames = L(source_mimic3.glob("**/*full*.pth")).map(str)
fname = Path(shutil.copy(model_fnames[0], tmp/'models')).name.split('.')[0]
print(f"Let's infer the pretrained model {fname}.")
Let's infer the pretrained model mimic3-9k_clas_full.
!./launches/launch_complete_mimic3 --fname {fname} --infer 1 --no_running_decoder
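The value passed as --fname is the model file name without its extension, obtained above with .name.split('.')[0]. The short sketch below shows what that expression returns and how it differs from Path.stem when a file name contains more than one dot. The second file name is hypothetical and is used only to show the difference.

```python
from pathlib import Path

# A checkpoint path like the one copied into tmp/models above.
p = Path("models/mimic3-9k_clas_full.pth")
name_by_split = p.name.split(".")[0]  # everything before the first dot
name_by_stem = p.stem                 # everything before the last dot

# Hypothetical file name with an extra dot, to show where the two differ.
multi = Path("models/mimic3.v2_clas_full.pth")
multi_by_split = multi.name.split(".")[0]
multi_by_stem = multi.stem

print(name_by_split, name_by_stem)
print(multi_by_split, multi_by_stem)
```

For the checkpoint names used in this README the two approaches agree, so the choice only matters if a model file name has additional dots.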

Train and Infer on MIMIC4-full

MIMIC4-full refers to the full MIMIC-IV dataset using ICD10 codes. (Refer to Automated Medical Coding on MIMIC-III and MIMIC-IV: A Critical Review and Replicability Study for details of how the data was curated)

data = join_path_file('mimic4_icd10_full', source_mimic4, ext='.csv')
!head -n 1 {data}
df = pd.read_csv(data,
                    usecols=['subject_id', '_id', 'text', 'labels', 'num_targets', 'is_valid', 'split'],
                    dtype={'subject_id': str, '_id': str, 'text': str, 'labels': str, 'num_targets': np.int64, 'is_valid': bool, 'split': str})

Let’s look at some of the descriptions of ICD10 codes:

stripped_codes = [''.join(filter(str.isalnum, s)) for s in df.labels[0].split(';')]
desc = get_description(stripped_codes)
pd.DataFrame(desc.items(), columns=['icd10_code', 'description'])
icd10_code description
0 E785 Hyperlipidemia, unspecified
1 F0280 ICD-10-PCS code structure
2 G3183 Neurocognitive disorder with Lewy bodies
3 R296 Repeated falls
4 R441 Visual hallucinations
5 Z8546 Personal history of malignant neoplasm of prostate
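The stripped_codes step above keeps only alphanumeric characters from each label before looking up its description. The sketch below isolates that step on a hand-written label string. It assumes the labels are stored in dotted form (for example E78.5), which is an assumption rather than something shown in this README. The helper name strip_code is hypothetical.

```python
def strip_code(code):
    """Drop every non-alphanumeric character (dots, spaces) from an ICD code."""
    return "".join(filter(str.isalnum, code))


# Assumed dotted form of the six codes listed in the table above.
raw_labels = "E78.5;F02.80;G31.83;R29.6;R44.1;Z85.46"
stripped = [strip_code(c) for c in raw_labels.split(";")]
print(stripped)
```

Codes that are already stored without dots pass through unchanged, so the step is safe to apply in either case.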

To run inference with one of our pretrained XMTC models on MIMIC4-full (inference metrics: Precision@5,8,15):

model_fname = Path('/home/deb/.xcube/data/mimic4/mimic4_icd10_clas_full.pth')
fname = Path(shutil.copy(model_fname, tmp/'models')).name.split('.')[0]
print(f"Let's infer the pretrained model {fname}.")
Let's infer the pretrained model mimic4_icd10_clas_full.
!./launches/launch_complete_mimic4_icd10 --fname mimic4_icd10_clas_full --no_running_decoder --infer 1
fname is: mimic4_icd10_clas_full
infer is: 1
diff_inattn is: 40
lin_sgdr_lr0 is: 1e-1
l2r_sgdr_lr0 is: 1e-1
plant is false
Training XMTC without Stateful Decoder
All arguments:
--epochs=[0, 0, 0, 0, 6]
--lrs_linattn=[(6e-2,1e-6), (1e-2,1e-6), (1e-2, 1e-6), (1e-2,1e-6), (1e-6,1e-6)]
--lrs_plant=[(6e-2,1e-6), (1e-2,1e-2), (1e-2, 1e-2), (1e-3,1e-3), (1e-5,1e-5)]
--lrs_sgdr_linattn=[(1e-1,1e-1), (1e-1,1e-1), (1e-2, 1e-6), (1e-2,1e-6), (1e-6,1e-6)]
--lrs_sgdr_plant=[(1e-1,1e-1), (1e-1,1e-1), (1e-1,1e-1), (1e-2,1e-6), (1e-6,1e-6)]
--wd_linattn=[0.01, 0.01, 0.01, 0.3]
--wd_plant=[0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.01]
--wd_mul_plant=[1.0, 1.0, 1.0, 1.0, 30.0]
--metrics=partial(precision_at_k, k=5); partial(precision_at_k, k=8); partial(precision_at_k, k=15)
The following values were not passed to `accelerate launch` and had defaults used instead:
    `--num_machines` was set to a value of `1`
    `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
Rank[0] Run: 0; epochs: 6; lr: 0.01; bs: 8
Training with running_decoder=False
best so far = None                                       
loss = 0.005077285226434469                                                                                              
precision_at_k = 0.7713261286738703
precision_at_k = 0.6909339965660034
precision_at_k = 0.5442615224051447
best so far = None


This repository is my attempt to create Extreme Multi-Label Text Classifiers using Language Model Fine-Tuning as proposed by Jeremy Howard and Sebastian Ruder in ULMFiT. I am also heavily influenced by fast.ai’s course Practical Deep Learning for Coders and the excellent library fastai. I have adopted the coding style of fastai, using the Jupyter-based dev environment nbdev. Since this is one of my first attempts to create a full-fledged Python library, I have at times replicated implementations from fastai with some modifications. A big thanks to Jeremy and his team at fast.ai for everything they have been doing to make AI accessible to everyone.

