The official llm2vec library

LLM2Vec: Large Language Models Are Secretly Powerful Text Encoders

LLM2Vec is a simple recipe to convert decoder-only LLMs into text encoders. It consists of 3 simple steps: 1) enabling bidirectional attention, 2) training with masked next token prediction, and 3) unsupervised contrastive learning. The model can be further fine-tuned to achieve state-of-the-art performance.
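
To make step 1 concrete, the toy sketch below contrasts a causal attention mask (each token sees only itself and earlier tokens) with a bidirectional one (each token sees the whole sequence). It is a plain-Python illustration of the idea only, not the library's implementation, and the function names are hypothetical.

```python
def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (only j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """Every position may attend to every other position."""
    return [[True] * n for _ in range(n)]


def render(mask):
    """Draw a mask as text: 'x' = attention allowed, '.' = blocked."""
    return "\n".join("".join("x" if allowed else "." for allowed in row) for row in mask)


print(render(causal_mask(4)))
print()
print(render(bidirectional_mask(4)))
```

In the causal case the upper triangle is blocked; enabling bidirectional attention amounts to removing that restriction.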

Installation

To use LLM2Vec, first install the llm2vec package from PyPI.

pip install llm2vec

You can also install it directly from source by cloning the repository and running the following from its root directory:

pip install -e .
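
A quick way to confirm that the installation succeeded is to query the installed package metadata. This is a small sketch using only the Python standard library (Python 3.8+); the helper name `installed_version` is hypothetical and not part of llm2vec.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the installed llm2vec version, or None if the package is missing.
print(installed_version("llm2vec"))
```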

Getting Started

The LLM2Vec class is a wrapper on top of HuggingFace models that supports sequence encoding and pooling operations. The steps below show an example of how to use the library.

Preparing the model

Here, we first initialize the base model and apply the MNTP-trained LoRA weights on top. After merging the MNTP weights into the model, we can either

  • load the unsupervised-trained LoRA weights (trained with the SimCSE objective and wiki corpus),
  • or load the supervised-trained LoRA weights (trained with contrastive learning and public E5 data).
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel

# Loading base MNTP model, along with custom code that enables bidirectional connections in decoder-only LLMs
tokenizer = AutoTokenizer.from_pretrained("McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp")
config = AutoConfig.from_pretrained("McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", trust_remote_code=True)
model = AutoModel.from_pretrained("McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", trust_remote_code=True, config=config, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(model, "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp")
model = model.merge_and_unload() # This can take several minutes

# Loading unsupervised-trained LoRA weights. This loads the trained LoRA weights on top of MNTP model. Hence the final weights are -- Base model + MNTP (LoRA) + SimCSE (LoRA).
model = PeftModel.from_pretrained(model, "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-unsup-simcse-mean")

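# Or, instead of the unsupervised weights above, load the supervised-trained LoRA weights.
# These are alternatives: apply one of the two adapters, not both.
# model = PeftModel.from_pretrained(model, "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised")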

Applying LLM2Vec wrapper

Then, we define our LLM2Vec encoder model as follows:

from llm2vec import LLM2Vec

l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)
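
The `pooling_mode="mean"` argument selects mean pooling, which turns per-token vectors into a single sequence embedding. The toy sketch below shows masked mean pooling in general, in plain Python; it is an assumption about the idea rather than the library's actual code, and `mean_pool` is a hypothetical name.

```python
def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors whose mask entry is 1, skipping padding (mask 0)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for k in range(dim):
                total[k] += vector[k]
    if count == 0:
        raise ValueError("attention_mask selects no tokens")
    return [t / count for t in total]


# Three token vectors; the last one is padding and must not affect the result.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = mean_pool(tokens, mask)
print(pooled)  # [2.0, 3.0]
```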

Inference

The model now returns a text embedding for any input of the form [[instruction1, text1], [instruction2, text2]] or [text1, text2]. During training, we provide instructions for both sentences in symmetric tasks, and only for queries in asymmetric tasks.

# Encoding queries using instructions
instruction = 'Given a web search query, retrieve relevant passages that answer the query:'
queries = [
    [instruction, 'how much protein should a female eat'],
    [instruction, 'summit define']
]
q_reps = l2v.encode(queries)

# Encoding documents. Instructions are not required for documents
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments."
]
d_reps = l2v.encode(documents)

# Compute cosine similarity
q_reps_norm = torch.nn.functional.normalize(q_reps, p=2, dim=1)
d_reps_norm = torch.nn.functional.normalize(d_reps, p=2, dim=1)
cos_sim = torch.mm(q_reps_norm, d_reps_norm.transpose(0, 1))

print(cos_sim)
"""
tensor([[0.5486, 0.0554],
        [0.0567, 0.5437]])
"""
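
Each row of the similarity matrix corresponds to a query and each column to a document, so retrieval reduces to picking the highest-scoring column per row. The sketch below reproduces the cosine-similarity step in plain Python on toy vectors and then ranks documents using the matrix printed above. The helper names are hypothetical and the toy vectors are not real embeddings.

```python
import math


def cosine_matrix(queries, documents):
    """Cosine similarity between every query vector and every document vector."""
    def norm(v):
        return math.sqrt(sum(x * x for x in v))

    return [
        [sum(a * b for a, b in zip(q, d)) / (norm(q) * norm(d)) for d in documents]
        for q in queries
    ]


def best_match(similarities):
    """For each query (row), return the index of the most similar document (column)."""
    return [max(range(len(row)), key=row.__getitem__) for row in similarities]


# Toy vectors: each query points in the same direction as exactly one document.
toy_sim = cosine_matrix([[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 1.0]])

# The similarity matrix from the example output above.
cos_sim = [[0.5486, 0.0554], [0.0567, 0.5437]]
best = best_match(cos_sim)
print(best)  # [0, 1]
```

Here the first query matches the first document and the second query matches the second, as expected from the diagonal of the matrix.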

Model List

Training

Training code will be available soon.

Bugs or questions?

If you have any questions about the code, feel free to email Parishad (parishad.behnamghader@mila.quebec) and Vaibhav (vaibhav.adlakha@mila.quebec).
