ULLME: A Unified Framework for Large Language Model Embeddings with Generation-Augmented Learning

ULLME is a flexible, plug-and-play implementation that enables bidirectional attention across various LLMs and supports a range of fine-tuning strategies to learn passage embeddings.

Installation

ULLME can be easily installed via one of the following methods:

Using pip

Coming soon!

From source

git clone https://github.com/nlp-uoregon/ullme.git
cd ullme
pip install -e .
# If you are using FlashAttention-2 (the default attention implementation in ULLME):
pip install flash-attn --no-build-isolation
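
Because flash-attn compiles native extensions, its installation can fail silently in some environments. A quick way to confirm it is importable before loading a model with ULLME's default attention setting is sketched below. This helper is not part of ULLME and its name is hypothetical; it relies only on the standard library's importlib.util.find_spec and on the fact that the flash-attn package installs a module named flash_attn.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the flash-attn package can be imported (hypothetical helper, not part of ULLME)."""
    # The flash-attn distribution installs a top-level module called `flash_attn`.
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    if flash_attn_available():
        print("flash-attn found: FlashAttention-2 is available.")
    else:
        print("flash-attn not found: install it before using ULLME's default attention setting.")
```

find_spec only locates the module without importing it, so the check is cheap; it does not verify that your GPU supports FlashAttention-2.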

Usage

ULLME offers the following main features:

Enabling Bidirectional Attention

ULLME enhances Hugging Face decoder-only Large Language Models (LLMs) with bidirectional attention, along with sequence encoding and pooling operations.

from ullme.models import ULLME

# Wrap a Hugging Face decoder-only LLM with ULLME (LoRA adapters + bidirectional attention)
model = ULLME(
    model_name_or_path="mistralai/Mistral-7B-v0.1",
    model_backbone_type="mistral",
    lora_name="ullme-mistral",
    loar_r=16,      # LoRA rank
    lora_alpha=32,  # LoRA scaling factor
)
input_sentence = "This is an example sentence."
model_inputs = model.tokenizer(
    [input_sentence],
    return_tensors='pt',
)
# is_generate=False -> bidirectional attention (embedding mode)
model_output = model(
    input_ids=model_inputs['input_ids'],
    attention_mask=model_inputs['attention_mask'],
    is_generate=False,
)

The model returned by ULLME is a PyTorch model, so users can integrate it into other frameworks or pipelines. By default, ULLME uses mean pooling to produce the sequence embedding. The is_generate parameter controls the attention mechanism: when set to False, the model uses bidirectional attention, which suits dense retrieval; when set to True, it reverts to causal attention, mimicking the output of a standard Hugging Face Transformers model.
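
To make the is_generate switch and mean pooling concrete, here is a framework-free sketch in plain Python. It is not ULLME's internal code: the function names attention_mask and mean_pool are hypothetical, the mask uses 1 for "may attend" and 0 for "blocked", and real implementations operate on batched tensors rather than nested lists.

```python
from typing import List


def attention_mask(seq_len: int, is_generate: bool) -> List[List[int]]:
    """Return a seq_len x seq_len mask; entry [i][j] is 1 if token i may attend to token j."""
    if is_generate:
        # Causal attention: each token sees itself and earlier tokens only.
        return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
    # Bidirectional attention: every token sees every other token.
    return [[1] * seq_len for _ in range(seq_len)]


def mean_pool(hidden_states: List[List[float]], padding_mask: List[int]) -> List[float]:
    """Average the token vectors whose padding_mask entry is 1, skipping padding tokens."""
    kept = [vec for vec, m in zip(hidden_states, padding_mask) if m == 1]
    if not kept:
        raise ValueError("padding_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


print(attention_mask(3, is_generate=True))   # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
print(attention_mask(3, is_generate=False))  # [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
print(mean_pool([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]], [1, 1, 0]))  # [2.0, 3.0]
```

The third token is treated as padding, so it is excluded from the average; this is why the tokenizer's attention_mask is passed to the model alongside input_ids.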

Fine-tuning Strategies

ULLME supports multiple fine-tuning strategies, configured through the trainer's arguments:

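from ullme.trainer import GradCacheTrainer

# `model` is a ULLME instance and `train_dataloader` a training DataLoader (both defined elsewhere)
trainer = GradCacheTrainer(
    con_loss_type='NTXentLoss',  # contrastive loss
    gen_loss_type='sigmoid',     # generative loss: 'sft', 'sigmoid' (DPO), 'kto', or 'ipo'
    use_kl_loss=True,            # enable the GRL alignment (KL) loss
)
trainer.fit_epoch(
    model=model,
    train_loader=train_dataloader,
)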

Contrastive Learning (CL)

ULLME enables efficient and effective CL, with advanced features that improve training and performance, such as GradCache, cross-device contrastive loss computation, and miners. CL is enabled by default.
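
The core of the contrastive objective can be sketched as an InfoNCE-style loss with in-batch negatives, a simplified, one-directional form of NT-Xent in which each query has exactly one positive passage. This plain-Python illustration rests on assumptions: cosine similarity, a temperature of 0.05, and the other passages in the batch acting as negatives. ULLME's actual loss (selected via con_loss_type='NTXentLoss'), GradCache, cross-device gathering, and miners are not reproduced here, and the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def in_batch_contrastive_loss(queries: List[Vector], passages: List[Vector],
                              temperature: float = 0.05) -> float:
    """Mean InfoNCE loss: passages[i] is the positive for queries[i];
    every other passage in the batch serves as a negative (assumed setup)."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cosine(q, p) / temperature for p in passages]
        # log-sum-exp with max subtraction for numerical stability
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
        # -log softmax probability of the positive passage
        total += log_denominator - logits[i]
    return total / len(queries)


queries = [[1.0, 0.0], [0.0, 1.0]]
print(in_batch_contrastive_loss(queries, [[1.0, 0.0], [0.0, 1.0]]))  # close to 0
print(in_batch_contrastive_loss(queries, [[0.0, 1.0], [1.0, 0.0]]))  # close to 20
```

Swapping the passages so each query's positive becomes dissimilar sharply increases the loss; that gap is the training signal that pulls matching query and passage embeddings together and pushes in-batch negatives apart.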

Generative Fine-tuning

Beyond Contrastive Learning (CL), ULLME supports Supervised Fine-Tuning (SFT) and a range of preference loss functions to further improve model performance. The loss function is selected through the gen_loss_type argument, which currently supports sft, sigmoid (i.e., DPO), kto, and ipo.
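
For reference, the sigmoid option corresponds to the standard DPO objective, loss = -log σ(β[(log π(y_w|x) - log π_ref(y_w|x)) - (log π(y_l|x) - log π_ref(y_l|x))]), where y_w is the preferred (chosen) response and y_l the rejected one. The sketch below computes it from summed sequence log-probabilities in plain Python. β = 0.1 is an assumed value, how ULLME obtains the log-probabilities (and whether it adds variants such as label smoothing) is not covered here, and dpo_sigmoid_loss is a hypothetical name.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def dpo_sigmoid_loss(policy_chosen_logp: float, policy_rejected_logp: float,
                     ref_chosen_logp: float, ref_rejected_logp: float,
                     beta: float = 0.1) -> float:
    """DPO loss for one preference pair, given sequence log-probabilities
    under the trained policy and a frozen reference model."""
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_ratio - rejected_ratio)
    # -log(sigmoid(margin)) == log(1 + exp(-margin)) == softplus(-margin)
    return softplus(-margin)


# Policy favours the chosen response more than the reference does -> loss below log(2)
print(dpo_sigmoid_loss(-12.0, -15.0, -13.0, -14.0))
```

With no preference shift relative to the reference model the margin is zero and the loss equals log 2; the loss falls as the policy raises the chosen response's likelihood relative to the rejected one.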

Alignment between Generation-based and Representation-based Scores

In ULLME, we also introduce a novel fine-tuning strategy, GRL, that explicitly aligns the model's understanding of relevance in the embedding and generation spaces through a Kullback-Leibler (KL) divergence loss. You can enable it by setting use_kl_loss=True.
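
One way to picture the GRL alignment term: turn each score type into a probability distribution over a query's candidate passages and penalize the KL divergence between them. In this sketch the representation-based scores are embedding similarities, and the generation-based scores are assumed to be something like the LLM's log-likelihood of the query given each passage. The KL direction KL(P_gen || P_rep), the temperature, and all function names are assumptions, not ULLME's documented choices.

```python
import math
from typing import List


def softmax(scores: List[float], temperature: float = 1.0) -> List[float]:
    """Convert raw scores into a probability distribution (max-shifted for stability)."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def alignment_loss(gen_scores: List[float], rep_scores: List[float],
                   temperature: float = 1.0) -> float:
    """Hypothetical GRL-style term: KL between generation- and representation-based
    relevance distributions over the same candidate passages."""
    p_gen = softmax(gen_scores, temperature)
    p_rep = softmax(rep_scores, temperature)
    return kl_divergence(p_gen, p_rep)


print(alignment_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
print(alignment_loss([3.0, 0.0], [0.0, 3.0]))            # positive
```

The loss is zero when the two score lists are identical or differ only by a constant shift (softmax is shift-invariant); when they rank candidates differently it is positive, and in an autograd framework its gradient pushes the two relevance distributions together.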

Evaluation on MTEB

from ullme.models import WrappedULLME
from ullme.eval import eval_mteb_dataset

# Load a fine-tuned checkpoint for MTEB evaluation
model = WrappedULLME(
    model_name_or_path="mistralai/Mistral-7B-v0.1",
    model_backbone_type="mistral",
    lora_name="ullme-mistral",
    loar_r=16,      # LoRA rank
    lora_alpha=32,  # LoRA scaling factor
    model_checkpoint="path/to/your/checkpoint",
)
eval_result = eval_mteb_dataset(
    model=model,
    dataset_name='MSMARCO',
    langs=['eng'],
)
print(eval_result)
# {'eng': 35.8}

ULLME streamlines evaluation by directly supporting MTEB for LLM-based text embedding models. The dataset_name and langs parameters select the specific dataset and language subsets to evaluate.
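
MSMARCO is a retrieval task in MTEB, where the main metric for retrieval is nDCG@10; the 35.8 shown above is presumably that score on a 0 to 100 scale, although the output itself does not say so. For intuition, nDCG@k for a single query can be computed as below, with linear gains and a log2 rank discount. This is an illustrative plain-Python sketch with a hypothetical function name; MTEB averages the metric over all queries using its own evaluation code.

```python
import math
from typing import Dict, List


def ndcg_at_k(ranked_ids: List[str], relevance: Dict[str, int], k: int = 10) -> float:
    """nDCG@k for one query: ranked_ids is the model's ranking, relevance maps
    judged document IDs to graded relevance (unjudged documents count as 0)."""
    def dcg(gains: List[int]) -> float:
        # rank 0 gets discount log2(2) = 1, rank 1 gets log2(3), ...
        return sum(g / math.log2(rank + 2) for rank, g in enumerate(gains))

    gains = [relevance.get(doc_id, 0) for doc_id in ranked_ids[:k]]
    ideal = sorted(relevance.values(), reverse=True)[:k]
    ideal_dcg = dcg(ideal)
    return dcg(gains) / ideal_dcg if ideal_dcg > 0 else 0.0


print(ndcg_at_k(["d1", "d2"], {"d1": 1}))  # 1.0 (relevant passage ranked first)
print(ndcg_at_k(["d2", "d1"], {"d1": 1}))  # about 0.63 (relevant passage ranked second)
```

Because the discount grows with rank, placing the relevant passage first matters far more than where it lands further down the top 10.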

Model List

We release three models fine-tuned with GRL on three popular LLMs: Meta-Llama-3-8B, Mistral-2-7B, and Phi-1.5B.

Fine-tuning CLI

To fine-tune the Meta-Llama-3-8B model, run the following command:

python -m genc.main \
    --config_file scripts/configs/llama.yaml \
    --nodes 4 \
    --devices 1 \
    --gc_chunk_size 4 \
    --output_dir output/ullme-grl-llam3

Evaluation CLI

To evaluate the model on the MTEB benchmark, run the following command:

python -m eval.eval_mteb \
    --config_file scripts/configs/llama.yaml \
    --output_dir output/ullme-grl-llam3/checkpoint.ckpt

Bugs or questions?

If you have any questions about the code, feel free to open an issue on the GitHub repository.

Download files

Source Distribution

ullme-0.0.1.tar.gz (52.9 kB)

Built Distribution

ullme-0.0.1-py3-none-any.whl (86.8 kB)

File details

Details for the file ullme-0.0.1.tar.gz.

File metadata

  • Download URL: ullme-0.0.1.tar.gz
  • Size: 52.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for ullme-0.0.1.tar.gz:

  • SHA256: 72c348a2db44a0e827abae62f1275e1dae8b631bb99b6e61d9a5c269ff961b8f
  • MD5: 70cc9ab626f921020e59ebe31b448a7f
  • BLAKE2b-256: 96072b0086f8c887d0579f325f806594b150cf5f25f946b23e1701945a0a2abf


File details

Details for the file ullme-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: ullme-0.0.1-py3-none-any.whl
  • Size: 86.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for ullme-0.0.1-py3-none-any.whl:

  • SHA256: 7d4e40c84ad0589695760844a74a577028e902188b786069455c2b4e754ccdae
  • MD5: 8c5c59341445205196c1601ffbe2410b
  • BLAKE2b-256: 4cdaa20620ab6bdcee77b410ebebef6d1c93917a1f3405a096712aee1d08ba37
