
Tabular-Infused Parameter Efficient Finetuning (tipeft)

Project description

tipeft

Tabular-infused Parameter Efficient Finetuning (tipeft) is a novel parameter-efficient finetuning (PEFT) method that infuses tabular features into the initialization of re-parameterization-based PEFT methods. The newly introduced PEFT parameters are usually initialized independently of the task data. tipeft instead gives them an informed starting point that already carries representational capacity derived from the tabular features.

[Figure: Overview of the tipeft framework]

It is designed specifically for postoperative prediction in clinical care, where predictive pre-operative tabular features are often under-utilized when finetuning language models. tipeft currently supports two PEFT methods: LoRA and IA3.
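The sketch below makes "initialized independently" concrete for the two supported methods. It is a toy NumPy example of the *standard* initializations, not of tipeft's tabular-infused initialization, which is not shown here. LoRA adds a low-rank update (alpha / r) * B A whose B matrix starts at zero. IA3 rescales an activation by a learned vector that starts at ones. Either way, the model's output at step 0 is identical to the frozen model's and is unaffected by any tabular data. The layer sizes, alpha, and the 0.01 scale on A are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r = 8, 8, 2                  # toy layer sizes and LoRA rank (arbitrary)
W0 = rng.standard_normal((d_out, d_in))   # frozen pretrained weight
x = rng.standard_normal(d_in)             # toy input activation
base = W0 @ x                             # output of the frozen layer

# LoRA: h = W0 x + (alpha / r) * B A x, with A random and B initialized to zeros.
alpha = 4
A = rng.standard_normal((r, d_in)) * 0.01
B = np.zeros((d_out, r))
h_lora = base + (alpha / r) * (B @ (A @ x))

# IA3: element-wise rescaling of an activation by a vector l initialized to ones.
l = np.ones(d_out)
h_ia3 = l * base

# Both start out as exact no-ops on the frozen model.
print(np.allclose(h_lora, base), np.allclose(h_ia3, base))  # True True
```

Because these default starting points carry no information about the task, they motivate replacing them with an initialization informed by the pre-operative tabular features.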

Requirements

Dependencies

The following Python packages are required for tipeft:

  • torch

  • transformers

  • peft

  • accelerate

  • numpy

  • pandas

  • scikit-learn

  • tqdm

Install dependencies with:

pip install torch transformers peft accelerate numpy pandas scikit-learn tqdm
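After installing, a quick way to confirm that every dependency is importable is sketched below. This helper is not part of tipeft. Note that scikit-learn is installed as `scikit-learn` but imported as `sklearn`.

```python
import importlib.util

# pip name -> import name (they differ only for scikit-learn)
REQUIRED = {
    "torch": "torch",
    "transformers": "transformers",
    "peft": "peft",
    "accelerate": "accelerate",
    "numpy": "numpy",
    "pandas": "pandas",
    "scikit-learn": "sklearn",
    "tqdm": "tqdm",
}


def missing_dependencies(required=REQUIRED):
    """Return the pip names of required packages that cannot be imported."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    if missing:
        print("Missing:", " ".join(missing))
    else:
        print("All tipeft dependencies are installed.")
```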

Note on PyTorch installation

Because PyTorch wheels vary by CUDA version and hardware, it is recommended to install PyTorch manually by following the instructions at https://pytorch.org/.

System Requirements

tipeft has been tested and verified on the following configuration:

  • OS: Windows 10

  • Python: 3.9.19

  • CUDA: 12.6

Important Notes

  • Environment: Must be run in a Jupyter notebook. Running as a standalone Python script may cause multiprocessing issues.

  • CPU cores: At least 10 CPU cores recommended (uses Pool(processes=10) internally).

  • GPU: CUDA-compatible GPU required.

  • OS: Tested on Windows. Linux/Mac compatibility not yet verified.
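Before training, you can run the check below in the notebook to see whether the CPU-core and Jupyter notes above are satisfied. It is a best-effort sketch and not a tipeft function. The 10-core threshold mirrors the `Pool(processes=10)` note. Jupyter is detected by assuming that notebook kernels use IPython's `ZMQInteractiveShell`.

```python
import os


def running_in_jupyter() -> bool:
    """Best-effort check for a Jupyter (IPython kernel) session."""
    try:
        from IPython import get_ipython
    except ImportError:
        return False  # IPython not installed, so not a notebook
    shell = get_ipython()  # None when IPython is installed but not running
    return shell is not None and type(shell).__name__ == "ZMQInteractiveShell"


def check_environment(min_cores: int = 10) -> dict:
    """Report CPU cores and whether we appear to be inside Jupyter."""
    cores = os.cpu_count() or 0
    return {
        "cpu_cores": cores,
        "enough_cores": cores >= min_cores,
        "jupyter": running_in_jupyter(),
    }


print(check_environment())
```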

Known Compatibility Limitations

  1. Jupyter only: uses tqdm.notebook, which may not display correctly outside Jupyter.

  2. Multiprocessing: may behave differently on Linux/Mac because the default process start method differs by platform (spawn on Windows and macOS; fork on Linux in Python versions before 3.14).

If you encounter issues on a different setup, please open an issue with your system info.
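When you report an issue, include the process start method your Python uses. The sketch below prints it and also shows the standard `if __name__ == "__main__":` guard that "spawn" requires when a `multiprocessing.Pool` is created from a script. This is a general Python pattern. We have not verified how tipeft creates its internal pool, so the guard alone may not make tipeft work outside Jupyter. The `square` worker is purely illustrative.

```python
import multiprocessing as mp


def square(n: int) -> int:
    # Workers must be module-level functions so "spawn" can pickle them by name.
    return n * n


def describe_start_method() -> str:
    """Return the active start method ("spawn", "fork", or "forkserver")."""
    return mp.get_start_method(allow_none=False)


if __name__ == "__main__":
    # Without this guard, "spawn" re-imports the script in every worker,
    # and each worker would try to create the pool again.
    print("start method:", describe_start_method())
    with mp.Pool(processes=2) as pool:
        print(pool.map(square, range(5)))  # [0, 1, 4, 9, 16]
```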

GPU requirements

tipeft is designed for GPU acceleration.

  • At least one CUDA-compatible GPU is required

  • Suggested minimum: 16 GB VRAM

  • Memory usage depends on:

    • sequence length

    • model size

    • batch size

    • PEFT configuration
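To see which GPUs PyTorch can use and how much memory each one has, you can use a sketch like the following. It relies only on `torch.cuda` calls and is not part of tipeft. The default threshold is 15 decimal GB rather than 16. That is an assumption: cards sold as "16 GB" often report slightly less total memory to PyTorch.

```python
def gpu_report(min_vram_gb: float = 15.0) -> list:
    """List visible CUDA GPUs and whether each meets the VRAM threshold."""
    try:
        import torch
    except ImportError:
        return []  # PyTorch is not installed
    if not torch.cuda.is_available():
        return []  # no CUDA device visible to PyTorch
    report = []
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        vram_gb = props.total_memory / 1e9  # bytes -> decimal GB
        report.append({
            "index": i,
            "name": props.name,
            "vram_gb": round(vram_gb, 1),
            "meets_minimum": vram_gb >= min_vram_gb,
        })
    return report


if __name__ == "__main__":
    gpus = gpu_report()
    if not gpus:
        print("No CUDA GPU visible to PyTorch; tipeft requires one.")
    for gpu in gpus:
        print(gpu)
```

If a GPU falls short, the list above suggests the knobs that matter: sequence length, model size, batch size, and the PEFT configuration.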

Installation

Install tipeft from PyPI with:

pip install tipeft
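To confirm which release is installed (this page describes 0.1.3), the standard-library `importlib.metadata` module is enough. The helper name is illustrative.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str = "tipeft"):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print("tipeft version:", installed_version())
```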

Usage

train_tabular_infused_IA3

Parameters

  • train (pandas.DataFrame): Training dataframe containing text, label, and tabular feature columns (required)

  • val (pandas.DataFrame): Validation dataframe with same structure as train (required)

  • pretrained_model_name (str): Base model to fine-tune. Supports "emilyalsentzer/Bio_ClinicalBERT" or "microsoft/biogpt" (required)

  • label_col (str): Column name of the binary outcome label. Must contain True/False values. (required)

  • text_col (str): Column name containing the clinical text (required)

  • columns_unique_labels_of_tabular_features (dict): Maps each tabular feature column name to its number of unique values: use 1 for continuous features and the number of categories (>1) for categorical features. (required)

  • lr (float): Learning rate for final model training (default: 0.001)

  • num_epochs (int): Number of training epochs (default: 5)

  • lr_of_tabular_infused_features (float): Learning rate for tabular pre-training (default: 0.0001)
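Instead of typing `columns_unique_labels_of_tabular_features` by hand, you can derive it from your data. The sketch below uses a hypothetical helper that is not part of tipeft. You declare which columns are continuous and which are categorical. The helper maps continuous columns to 1 and categorical columns to their number of distinct non-null values. It assumes that missing values should not count as a category, and we have not verified how tipeft treats NaN. Pass the train and val frames concatenated so that the counts cover both splits.

```python
import pandas as pd


def infer_unique_label_map(df: pd.DataFrame, continuous: list, categorical: list) -> dict:
    """Hypothetical helper: build columns_unique_labels_of_tabular_features.

    Continuous columns map to 1; categorical columns map to their number of
    distinct non-null values, which must be greater than 1.
    """
    mapping = {}
    for col in continuous:
        mapping[col] = 1
    for col in categorical:
        n = int(df[col].nunique(dropna=True))
        if n < 2:
            raise ValueError(f"categorical column {col!r} has only {n} unique value(s)")
        mapping[col] = n
    return mapping


# Made-up toy data for illustration only.
df = pd.DataFrame({
    "gender": ["F", "M", "F"],
    "insurance": ["Medicare", "Medicaid", "Other"],
    "anchor_age": [64, 71, 58],
})
print(infer_unique_label_map(df, continuous=["anchor_age"], categorical=["gender", "insurance"]))
# {'anchor_age': 1, 'gender': 2, 'insurance': 3}
```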

Returns

  • model (PeftModel): The trained IA3 model

  • tokenizer (AutoTokenizer): The tokenizer for the model

Example use case

from tipeft import train_tabular_infused_IA3

model, tokenizer = train_tabular_infused_IA3(
    train=train_df,
    val=val_df,
    pretrained_model_name="emilyalsentzer/Bio_ClinicalBERT",
    label_col="in_hospital_mortality",
    text_col="clinical_notes",
    columns_unique_labels_of_tabular_features={
        "gender": 2,
        "insurance": 3,
        "marital_status": 4,
        "anchor_age": 1,
        "anchor_year": 1
    },
    lr=0.001,
    num_epochs=5,
    lr_of_tabular_infused_features=0.0001
)
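The example assumes `train_df` and `val_df` already exist. The sketch below builds a tiny toy pair with the column layout the example expects: a text column, a boolean label column, and the five tabular features, with 2, 3, and 4 distinct values for the categorical ones. All values are made up for illustration, and a real run needs far more rows and disjoint splits.

```python
import pandas as pd

# Made-up toy rows matching the column names used in the example call.
rows = {
    "clinical_notes": [
        "Elective hip replacement, uneventful recovery.",
        "Emergency laparotomy, prolonged ICU stay.",
        "Routine cholecystectomy, discharged on day 2.",
        "Cardiac surgery complicated by sepsis.",
    ],
    "in_hospital_mortality": [False, True, False, True],  # must be True/False
    "gender": ["F", "M", "M", "F"],                        # 2 categories
    "insurance": ["Medicare", "Medicaid", "Other", "Medicare"],   # 3 categories
    "marital_status": ["MARRIED", "SINGLE", "WIDOWED", "DIVORCED"],  # 4 categories
    "anchor_age": [67, 54, 41, 72],                        # continuous
    "anchor_year": [2150, 2160, 2155, 2170],               # continuous
}
train_df = pd.DataFrame(rows)
val_df = train_df.copy()  # toy only: real train/val splits must not overlap
```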

Notes

  • The label_col must contain boolean values (True/False)

  • Categorical features should be mapped to their number of unique labels (>1) in columns_unique_labels_of_tabular_features

  • Continuous/numerical features should be mapped to 1 in columns_unique_labels_of_tabular_features

  • Ensure all unique values in categorical columns appear in both train and val sets

  • The trained model is saved to trained_models/IA3_{pretrained_model_name}_{label_col}
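The notes above can be checked before a long training run. Below is a hypothetical pre-flight helper, not part of tipeft. It confirms that the required columns exist, that the label column has a `bool` dtype, and that each categorical feature has the same set of values in train and val, matching the declared count. Treating a pandas nullable `boolean` column as invalid is a deliberately strict assumption.

```python
import pandas as pd


def check_inputs(train, val, label_col, text_col, feature_map):
    """Hypothetical pre-flight checks mirroring the notes; returns a list of problems."""
    problems = []
    for name, df in (("train", train), ("val", val)):
        missing = [c for c in [label_col, text_col, *feature_map] if c not in df.columns]
        if missing:
            problems.append(f"{name}: missing columns {missing}")
            continue
        if df[label_col].dtype != bool:
            problems.append(f"{name}: {label_col!r} is {df[label_col].dtype}, expected bool")
    if problems:
        return problems  # column problems make the value checks meaningless
    for col, n in feature_map.items():
        if n == 1:
            continue  # continuous feature: nothing to compare
        train_vals = set(train[col].dropna().unique())
        val_vals = set(val[col].dropna().unique())
        if train_vals != val_vals:
            only_one = sorted(map(str, train_vals ^ val_vals))
            problems.append(f"{col!r}: values not in both train and val: {only_one}")
        found = len(train_vals | val_vals)
        if found != n:
            problems.append(f"{col!r}: declared {n} unique labels, found {found}")
    return problems


# Made-up toy frames.
train = pd.DataFrame({"t": ["a", "b"], "y": [True, False], "g": ["F", "M"], "age": [50, 60]})
val = pd.DataFrame({"t": ["c", "d"], "y": [False, True], "g": ["M", "F"], "age": [55, 65]})
print(check_inputs(train, val, "y", "t", {"g": 2, "age": 1}))  # []
```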

train_tabular_infused_lora

Parameters

  • train (pandas.DataFrame): Training dataframe containing text, label, and tabular feature columns (required)

  • val (pandas.DataFrame): Validation dataframe with same structure as train (required)

  • pretrained_model_name (str): Base model to fine-tune. Supports "emilyalsentzer/Bio_ClinicalBERT" or "microsoft/biogpt" (required)

  • label_col (str): Column name of the binary outcome label. Must contain True/False values. (required)

  • text_col (str): Column name containing the clinical text (required)

  • columns_unique_labels_of_tabular_features (dict): Maps each tabular feature column name to its number of unique values: use 1 for continuous features and the number of categories (>1) for categorical features. (required)

  • lr (float): Learning rate for final model training (default: 0.001)

  • num_epochs (int): Number of training epochs (default: 5)

  • lr_of_tabular_infused_features (float): Learning rate for tabular pre-training (default: 0.0001)

Returns

  • model (PeftModel): The trained LoRA model

  • tokenizer (AutoTokenizer): The tokenizer for the model

Example use case

from tipeft import train_tabular_infused_lora

model, tokenizer = train_tabular_infused_lora(
    train=train_df,
    val=val_df,
    pretrained_model_name="emilyalsentzer/Bio_ClinicalBERT",
    label_col="in_hospital_mortality",
    text_col="clinical_notes",
    columns_unique_labels_of_tabular_features={
        "gender": 2,
        "insurance": 3,
        "marital_status": 4,
        "anchor_age": 1,
        "anchor_year": 1
    },
    lr=0.001,
    num_epochs=5,
    lr_of_tabular_infused_features=0.0001
)

Notes

  • The label_col must contain boolean values (True/False)

  • Categorical features should be mapped to their number of unique labels (>1) in columns_unique_labels_of_tabular_features

  • Continuous/numerical features should be mapped to 1 in columns_unique_labels_of_tabular_features

  • Ensure all unique values in categorical columns appear in both train and val sets

  • The trained model is saved to trained_models/lora_{pretrained_model_name}_{label_col}
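To locate a saved run afterwards, the sketch below rebuilds the two save-path patterns given in the notes (`IA3_...` and `lora_...`). It is an illustration, not a tipeft function. Both supported model names contain a "/". Taken literally, the pattern therefore resolves to a nested directory (for example `trained_models/lora_emilyalsentzer/Bio_ClinicalBERT_in_hospital_mortality`). We have not verified whether tipeft sanitizes the name, so check the `exists` output.

```python
from pathlib import Path


def expected_save_dir(method: str, pretrained_model_name: str, label_col: str) -> Path:
    """Save-path pattern from the notes; method is "IA3" or "lora"."""
    return Path("trained_models") / f"{method}_{pretrained_model_name}_{label_col}"


if __name__ == "__main__":
    for method in ("IA3", "lora"):
        d = expected_save_dir(method, "emilyalsentzer/Bio_ClinicalBERT", "in_hospital_mortality")
        print(d.as_posix(), "exists" if d.exists() else "not found")
```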

Questions?

Contact me at alba@wustl.edu

Download files

Download the file for your platform.

Source Distribution

tipeft-0.1.3.tar.gz (237.0 kB)

Uploaded Source

Built Distribution


tipeft-0.1.3-py3-none-any.whl (13.2 kB)

Uploaded Python 3

File details

Details for the file tipeft-0.1.3.tar.gz.

File metadata

  • Download URL: tipeft-0.1.3.tar.gz
  • Upload date:
  • Size: 237.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for tipeft-0.1.3.tar.gz
Algorithm Hash digest
SHA256 08320824d32260c726e01fb341e10903548c2496cbca38c323e47391c77dca10
MD5 4addd169e8d8bcf8d4f710a99a0ed1c0
BLAKE2b-256 18468c03f07624a249772a951922f4ee3d5419cefe435311f826bc1920a5ac6d


File details

Details for the file tipeft-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: tipeft-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 13.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for tipeft-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 1b51a85122b73a69bc3eb1d6fcb4c39d682a969fbf0d4c0394923d2dbd9e3501
MD5 bd7ca9504d5606f4f0f856a2f75fb699
BLAKE2b-256 3621bd5bad99c5892095ecde3e7fc6cf038293c9d00d7f701e9cd03a6d927112

