
Text Embeddings for Retrieval and RAG based on transformers

[structure diagram]

Open-retrievals unifies text embedding, retrieval, reranking and RAG. It is easy to use, flexible and scalable.

  • Embeddings fine-tuned with point-wise, pairwise, listwise, contrastive-learning, and LLM-based objectives.
  • Rerankers fine-tuned as cross-encoder, ColBERT, or LLM models.
  • Easily build an enhanced, modular RAG pipeline integrated with Transformers, LangChain, and LlamaIndex.
Exp                              Model                Original  Fine-tuned  Demo
Embedding, pairwise fine-tune    bge-base-zh-v1.5     0.657     0.703       Open In Colab
Embedding, LLM fine-tune (LoRA)  Qwen2-1.5B-Instruct  0.546     0.695       Open In Colab
Rerank, cross-encoder            bge-reranker-base    0.666     0.706       Open In Colab
Rerank, ColBERT                  bge-m3               0.657     0.695       Open In Colab
Rerank, LLM (LoRA)               Qwen2-1.5B-Instruct  0.531     0.699       Open In Colab

Installation

Prerequisites

pip install transformers
pip install faiss-cpu  # optional: needed for dense retrieval indexing
pip install peft  # optional: needed for LoRA fine-tuning

With pip

pip install open-retrievals

With source code

python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git

Quick-start

Open In Colab

Embeddings from pretrained weights

from retrievals import AutoModelForEmbedding

sentences = [
    'query: how much protein should a female eat',
    'query: summit define',
    "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "passage: Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments."
]
model_name_or_path = 'intfloat/e5-base-v2'
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())

Index building for dense retrieval search

from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

sentences = ['A dog is chasing a car.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
index_path = './database/faiss/faiss.index'
model = AutoModelForEmbedding.from_pretrained(model_name_or_path)
model.build_index(sentences, index_path=index_path)

query_embed = model.encode("He plays guitar.")
matcher = AutoModelForRetrieval()
dists, indices = matcher.search(query_embed, index_path=index_path)
print(indices)
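
The returned indices point back into the list of indexed sentences. A minimal sketch for inspecting the hits, assuming search returns one row of distances and indices per query (the exact return shape is an assumption here):

# Print the retrieved passages for the single query.
# Assumes dists/indices hold one row per query, ordered best match first.
for rank, (idx, dist) in enumerate(zip(indices[0], dists[0]), start=1):
    print(f"{rank}. {sentences[int(idx)]}  (distance={float(dist):.4f})")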

Rerank using pretrained weights

from retrievals import AutoModelForRanking

model_name_or_path: str = "BAAI/bge-reranker-base"
rerank_model = AutoModelForRanking.from_pretrained(model_name_or_path)
scores_list = rerank_model.compute_score(["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."])
print(scores_list)

RAG with LangChain integration

Open In Colab

pip install langchain
pip install langchain_community
pip install chromadb

from retrievals.tools.langchain import LangchainEmbedding, LangchainReranker, LangchainLLM
from retrievals import AutoModelForRanking
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores import Chroma as Vectorstore
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import RetrievalQA

persist_directory = './database/chroma'
embed_model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
rerank_model_name_or_path = "BAAI/bge-reranker-base"
llm_model_name_or_path = "microsoft/Phi-3-mini-128k-instruct"

embeddings = LangchainEmbedding(model_name=embed_model_name_or_path)
vectordb = Vectorstore(
    persist_directory=persist_directory,
    embedding_function=embeddings,
)
retrieval_args = {
    "search_type": "similarity_score_threshold",
    "search_kwargs": {"k": 10, "score_threshold": 0.15},
}
retriever = vectordb.as_retriever(**retrieval_args)

ranker = AutoModelForRanking.from_pretrained(rerank_model_name_or_path)
reranker = LangchainReranker(model=ranker, top_n=3)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker, base_retriever=retriever
)

llm = LangchainLLM(model_name_or_path=llm_model_name_or_path)

RESPONSE_TEMPLATE = """[INST]
<<SYS>>
You are a helpful AI assistant. Use the following pieces of context to answer the user's question.
<</SYS>>
Anything between the following `context` html blocks is retrieved from a knowledge base.

<context>
    {context}
</context>

REMEMBER:
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
- Let's take a deep breath and think step-by-step.

Question: {question}[/INST]
Helpful Answer:
"""

PROMPT = PromptTemplate(template=RESPONSE_TEMPLATE, input_variables=["context", "question"])

qa_chain = RetrievalQA.from_chain_type(
    llm,
    chain_type='stuff',
    retriever=compression_retriever,
    chain_type_kwargs={
        "verbose": True,
        "prompt": PROMPT,
    }
)

user_query = 'Introduce this'
response = qa_chain({"query": user_query})
print(response)
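
The example above assumes the Chroma store under persist_directory is already populated. A minimal ingestion sketch using standard LangChain components (the source file path is hypothetical):

from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from retrievals.tools.langchain import LangchainEmbedding

# Load and chunk a local document (hypothetical path; replace with your corpus).
docs = TextLoader('./docs/example.txt').load()
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)

# Embed the chunks and persist them where the RetrievalQA example expects them.
embeddings = LangchainEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
Chroma.from_documents(chunks, embedding=embeddings, persist_directory='./database/chroma')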

Fine-tune Embedding

import torch.nn as nn
from torch.optim import AdamW
from datasets import load_dataset
from transformers import AutoTokenizer, get_linear_schedule_with_warmup, TrainingArguments
from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
batch_size: int = 128
epochs: int = 3

train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'document'})
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
# model = model.set_train_type('pointwise')  # 'pointwise', 'pairwise', 'listwise'
optimizer = AdamW(model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

training_arguments = TrainingArguments(
    output_dir='./checkpoints',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    remove_unused_columns=False,
)
trainer = RetrievalTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=PairCollator(tokenizer, query_max_length=128, document_max_length=128),
    loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()
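
With the model still in memory, a quick sanity check is to embed a related sentence pair and look at the cosine similarity (the sentences below are arbitrary examples, not taken from the training set):

import torch.nn.functional as F

# Encode two semantically related Chinese sentences and compare them;
# the similarity should be noticeably higher after STS-B fine-tuning.
emb = model.encode(['一个男人在弹吉他', '有人在演奏乐器'], convert_to_tensor=True)
print(F.cosine_similarity(emb[0:1], emb[1:2]).item())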

Fine-tune cross-encoder reranking

from torch.optim import AdamW
from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup
from retrievals import RerankCollator, AutoModelForRanking, RerankTrainer, RerankTrainDataset

model_name_or_path: str = "microsoft/deberta-v3-base"
max_length: int = 128
learning_rate: float = 3e-5
batch_size: int = 4
epochs: int = 3

train_dataset = RerankTrainDataset('./t2rank.json', positive_key='pos', negative_key='neg')
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForRanking.from_pretrained(model_name_or_path)
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

training_args = TrainingArguments(
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    num_train_epochs=epochs,
    output_dir='./checkpoints',
    remove_unused_columns=False,
)
trainer = RerankTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=RerankCollator(tokenizer, max_length=max_length),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()
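
The './t2rank.json' file is expected to hold query records with positive and negative passages keyed by 'pos' and 'neg', matching the positive_key/negative_key arguments above. The exact schema is an assumption here; a sketch that writes one JSON-lines record in that layout:

import json

# Hypothetical training record in the query / pos / neg layout implied by
# positive_key='pos' and negative_key='neg' (schema assumed, not taken from
# the library documentation).
record = {
    "query": "什么是文本重排序?",
    "pos": ["重排序模型对召回的候选文档按相关性重新打分。"],
    "neg": ["今天的天气非常好。"],
}
with open('./t2rank.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(record, ensure_ascii=False) + '\n')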

Fine-tune ColBERT reranking

import os
import transformers
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    get_cosine_schedule_with_warmup,
)

from retrievals import ColBERT, ColBertCollator, RerankTrainer, RetrievalTrainDataset
from retrievals.losses import ColbertLoss

transformers.logging.set_verbosity_error()
os.environ["WANDB_DISABLED"] = "true"

model_name_or_path: str = "BAAI/bge-m3"
learning_rate: float = 5e-6
batch_size: int = 32
epochs: int = 3
colbert_dim: int = 1024
output_dir: str = './checkpoints'

train_dataset = RetrievalTrainDataset(
    'C-MTEB/T2Reranking', positive_key='positive', negative_key='negative', dataset_split='dev'
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
data_collator = ColBertCollator(
    tokenizer,
    query_max_length=128,
    document_max_length=256,
    positive_key='positive',
    negative_key='negative',
)
model = ColBERT.from_pretrained(
    model_name_or_path,
    colbert_dim=colbert_dim,
    loss_fn=ColbertLoss(use_inbatch_negative=False),
)

optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
)

training_args = TrainingArguments(
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    num_train_epochs=epochs,
    output_dir=output_dir,
    remove_unused_columns=False,
    logging_steps=100,
)
trainer = RerankTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()

Reference & Acknowledgement
