ULLME: A Unified Framework for Large Language Model Embeddings with Generation-Augmented Learning

ULLME is a flexible, plug-and-play implementation that enables bidirectional attention across various LLMs and supports a range of fine-tuning strategies to learn passage embeddings.

Installation

ULLME can be easily installed via one of the following methods:

Using pip

pip install ullme
# if you are using flash-attention-2 (this is the default for ullme)
pip install flash-attn --no-build-isolation

From source

git clone https://github.com/nlp-uoregon/ullme.git
cd ullme
pip install -e .
# if you are using flash-attention-2 (this is the default for ullme)
pip install flash-attn --no-build-isolation
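To verify the installation (optional), you can check that the package imports cleanly:

python -c "from ullme.models import ULLME"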

Usage

ULLME offers the following main features:

Enabling Bidirectional Attention

ULLME enhances Hugging Face models by adding support for bidirectional processing in decoder-only Large Language Models (LLMs), as well as sequence encoding and pooling operations.

from ullme.models import ULLME

model = ULLME(
    model_name_or_path="mistralai/Mistral-7B-v0.1",
    model_backbone_type="mistral",
    lora_name="ullme-mistral",
    lora_r=16,
    lora_alpha=32,
    )
input_sentence = "This is an example sentence."
model_inputs = model.tokenizer(
    [input_sentence], 
    return_tensors='pt'
    )
model_output = model(
    input_ids=model_inputs['input_ids'],
    attention_mask=model_inputs['attention_mask'],
    is_generate=False
    )

The model returned by ULLME is a standard PyTorch module, giving users the flexibility to integrate it into various frameworks and pipelines. By default, the ULLME model uses mean pooling. The is_generate parameter controls the attention mechanism: when set to False, the model employs bidirectional attention, optimized for dense retrieval tasks; when set to True, the model reverts to causal attention, mimicking standard Hugging Face Transformers behavior.
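Embeddings produced this way can be compared directly, for example with cosine similarity. The sketch below reuses the model from above; note that the output field name reps is an assumption, since the exact structure of model_output is not shown here.

import torch.nn.functional as F

def encode(text):
    # Encode with bidirectional attention (is_generate=False) and return the
    # pooled sentence embedding. The 'reps' key is hypothetical; check the
    # actual output object for the real field name.
    inputs = model.tokenizer([text], return_tensors='pt')
    output = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        is_generate=False,
    )
    return output['reps']

# Cosine similarity between two sentence embeddings.
score = F.cosine_similarity(
    encode("A man is eating food."),
    encode("Someone is having a meal."),
)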

Fine-tuning Strategies

The ULLME framework supports multiple fine-tuning strategies:

from ullme.trainer import GradCacheTrainer

trainer = GradCacheTrainer(
    con_loss_type='NTXentLoss',  # contrastive loss
    gen_loss_type='sigmoid',     # or 'sft', 'kto', 'ipo'
    use_kl_loss=True             # align embedding and generation scores (GRL)
)
trainer.fit_epoch(
    model=model,
    train_loader=train_dataloader,
)

Contrastive Learning (CL)

ULLME enables efficient and effective CL. It comes with a range of advanced features designed to enhance the CL process and optimize performance, such as GradCache, cross-device contrastive loss computation, and miners, among others. Note that ULLME enables CL by default; a sketch of a minimal training dataloader follows.
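The exact batch schema that GradCacheTrainer expects is not documented here; as a rough sketch, contrastive training data is typically organized as query-passage pairs (the field names below are hypothetical):

from torch.utils.data import DataLoader

# Hypothetical (query, positive passage) pairs; the field names and collation
# expected by GradCacheTrainer may differ.
train_examples = [
    {"query": "what is dense retrieval?",
     "pos": "Dense retrieval encodes queries and passages into vectors."},
    {"query": "capital of France",
     "pos": "Paris is the capital and largest city of France."},
]
train_dataloader = DataLoader(train_examples, batch_size=2, shuffle=True)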

Generation-based Fine-tuning

ULLME not only supports Contrastive Learning (CL) but also enables Supervised Fine-Tuning (SFT) and provides a range of preference loss functions to further enhance model performance. The loss function can be selected through the gen_loss_type argument; currently supported values are sft, sigmoid (i.e., DPO), kto, and ipo.
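For example, switching from the DPO-style sigmoid loss to plain SFT is a one-argument change (a minimal sketch using the GradCacheTrainer arguments shown above):

from ullme.trainer import GradCacheTrainer

# SFT-style generative fine-tuning; preference losses ('sigmoid', 'kto',
# 'ipo') can be swapped in via gen_loss_type.
sft_trainer = GradCacheTrainer(
    con_loss_type='NTXentLoss',
    gen_loss_type='sft',
    use_kl_loss=False,
)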

Alignment between Generation-based and Representation-based Scores

In ULLME, we also introduce a novel fine-tuning strategy, GRL, which explicitly aligns the model's understanding of relevance in the embedding and generation spaces through a Kullback-Leibler (KL) divergence loss. You can enable this by setting use_kl_loss=True.
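Conceptually, the alignment can be pictured as a KL divergence between two relevance distributions over candidate passages. The sketch below is a simplified illustration, not ULLME's actual implementation; the scoring functions and tensor shapes are assumptions.

import torch
import torch.nn.functional as F

# Assumed shapes: (batch, num_candidates). rel_emb holds embedding-space
# similarity scores; rel_gen holds generation-based relevance scores.
rel_emb = torch.randn(8, 16)
rel_gen = torch.randn(8, 16)

# KL divergence pushing the embedding-space distribution toward the
# generation-space distribution; the input must be log-probabilities.
kl_loss = F.kl_div(
    F.log_softmax(rel_emb, dim=-1),
    F.softmax(rel_gen, dim=-1),
    reduction='batchmean',
)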

Evaluation on MTEB

from ullme.models import WrappedULLME
from ullme.eval import eval_mteb_dataset

model = WrappedULLME(
    model_name_or_path="mistralai/Mistral-7B-v0.1",
    model_backbone_type="mistral",
    lora_name="ullme-mistral",
    lora_r=16,
    lora_alpha=32,
    model_checkpoint="path/to/your/checkpoint"
    )
eval_result = eval_mteb_dataset(
    model=model,
    dataset_name='MSMARCO',
    langs=['eng'],
    )
# Output: {'eng': 35.8}

ULLME streamlines evaluation by integrating direct support for evaluating LLM-based text embedding models on MTEB. Users can select specific datasets and language subsets through the dataset_name and langs parameters, as in the loop sketched below.
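For example, to evaluate on several MTEB datasets in one pass (the dataset names other than MSMARCO are illustrative picks from MTEB):

# Reuses the WrappedULLME model and eval_mteb_dataset helper from above.
results = {}
for dataset in ['MSMARCO', 'SciFact', 'NFCorpus']:
    results[dataset] = eval_mteb_dataset(
        model=model,
        dataset_name=dataset,
        langs=['eng'],
    )
print(results)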

Model List

We publish three models fine-tuned with GRL on three popular LLMs: Meta-Llama-3-8B, Mistral-2-7B, and Phi-1.5B.

Finetuning CLI

To finetune the Meta-Llama-3-8B model, run the following command:

python -m genc.main \
    --config_file scripts/configs/llama.yaml \
    --nodes 4 \
    --devices 1 \
    --gc_chunk_size 4 \
    --output_dir output/ullme-grl-llam3

Evaluation CLI

To evaluate the model on the MTEB benchmark, run the following command:

python -m eval.eval_mteb \
    --config_file scripts/configs/llama.yaml \
    --output_dir output/ullme-grl-llam3/checkpoint.ckpt

Bugs or questions?

If you have any questions about the code, feel free to open an issue on the GitHub repository.
