Text Embeddings for Retrieval and RAG based on transformers
Open-retrievals unifies text embedding, retrieval, reranking, and RAG. It is easy to use, flexible, and scalable.
- Embeddings fine-tuned with point-wise, pairwise, listwise, and contrastive learning, and with LLMs.
- Rerankers fine-tuned with cross-encoder, ColBERT, and LLM models.
- Easily build enhanced modular RAG, integrated with Transformers, LangChain, and LlamaIndex.
- Term matching and semantic matching.
- The metric is MAP on 10% of the t2-reranking data. The original scores for the LLM and ColBERT models are zero-shot.
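For readers unfamiliar with the metric named in the last bullet, mean average precision (MAP) can be sketched as follows. This is a minimal illustration, not the project's evaluation code; the function names and the toy relevance labels are assumptions made for the example.

```python
from typing import List, Sequence


def average_precision(ranked_relevance: Sequence[int]) -> float:
    """Average precision for one query.

    ranked_relevance[i] is 1 if the document at position i of the ranking
    is relevant and 0 otherwise. This variant normalizes by the number of
    relevant documents that appear in the ranked list.
    """
    hits = 0
    precision_sum = 0.0
    for rank, rel in enumerate(ranked_relevance, start=1):
        if rel:
            hits += 1
            precision_sum += hits / rank  # precision at this cut-off
    return precision_sum / hits if hits else 0.0


def mean_average_precision(all_queries: List[Sequence[int]]) -> float:
    """Mean of the per-query average precision values."""
    if not all_queries:
        return 0.0
    return sum(average_precision(q) for q in all_queries) / len(all_queries)


# Toy labels (hypothetical): two queries, documents already sorted by model score.
toy_rankings = [[1, 0, 1], [0, 1, 0]]
map_score = mean_average_precision(toy_rankings)
print(round(map_score, 4))
```

A ranking that places relevant documents earlier yields a higher score, which is why the metric suits reranking comparisons.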
Installation
Prerequisites
pip install transformers
pip install faiss-cpu # if necessary
pip install peft # if necessary
With pip
pip install open-retrievals
With source code
python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
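After installing, it can help to confirm which of the packages above are present. The snippet below is a small sketch using only the standard library (Python 3.8+); the helper name `installed_version` is an assumption made for this example and is not part of open-retrievals.

```python
from importlib import metadata
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


# Distribution names are taken from the pip commands in this section.
for name in ("open-retrievals", "transformers", "faiss-cpu", "peft"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```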
Quick-start
Embeddings from pretrained weights
from retrievals import AutoModelForEmbedding

# E5 models expect a "query: " or "passage: " prefix on every input text.
sentences = [
    'query: how much protein should a female eat',
    'query: summit define',
    "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]
model_name_or_path = 'intfloat/e5-base-v2'
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)

# Rows are the two queries, columns are the two passages.
# With normalized embeddings, the dot product equals the cosine similarity (scaled by 100 here).
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())
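The `scores` line above is a matrix product between query and passage embeddings. Assuming `normalize_embeddings=True` L2-normalizes each vector, that product is a cosine similarity, scaled by 100 for readability. The sketch below reproduces the arithmetic in plain Python on toy vectors; the helper names and the vectors are hypothetical and stand in for real model output.

```python
import math
from typing import List


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def score_matrix(queries: List[List[float]], passages: List[List[float]],
                 scale: float = 100.0) -> List[List[float]]:
    """Cosine similarity of every query against every passage, times `scale`."""
    q = [l2_normalize(v) for v in queries]
    p = [l2_normalize(v) for v in passages]
    return [[scale * dot(qi, pj) for pj in p] for qi in q]


# Toy 2-dimensional "embeddings" (hypothetical values).
toy_queries = [[1.0, 0.0], [0.0, 2.0]]
toy_passages = [[3.0, 0.0], [0.0, 1.0]]
toy_scores = score_matrix(toy_queries, toy_passages)
print(toy_scores)
```

Each row corresponds to a query and each column to a passage, so the largest value in a row marks the best-matching passage for that query.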
Index building for dense retrieval search
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

sentences = ['A dog is chasing a car.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
index_path = './database/faiss/faiss.index'

# Embed the corpus and build the index at index_path.
model = AutoModelForEmbedding.from_pretrained(model_name_or_path)
model.build_index(sentences, index_path=index_path)

# Embed the query and look up its nearest neighbors in the index.
query_embed = model.encode("He plays guitar.")
matcher = AutoModelForRetrieval()
dists, indices = matcher.search(query_embed, index_path=index_path)
print(indices)
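Conceptually, a dense index search returns the scores and positions of the stored vectors closest to the query. The sketch below does this by exhaustive comparison in plain Python. It is not how Faiss or `AutoModelForRetrieval` is implemented; the function `brute_force_search`, the inner-product scoring, and the toy vectors are assumptions made for illustration.

```python
from typing import List, Tuple


def brute_force_search(query: List[float], corpus: List[List[float]],
                       top_k: int = 1) -> Tuple[List[float], List[int]]:
    """Return (scores, indices) of the top_k corpus vectors by inner product."""
    scored = [
        (sum(q * c for q, c in zip(query, vec)), idx)
        for idx, vec in enumerate(corpus)
    ]
    # Highest inner product first.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    top = scored[:top_k]
    return [score for score, _ in top], [idx for _, idx in top]


# Toy corpus and query embeddings (hypothetical values).
toy_corpus = [[0.9, 0.1], [0.1, 0.9]]
toy_query = [0.2, 0.8]
toy_dists, toy_indices = brute_force_search(toy_query, toy_corpus, top_k=2)
print(toy_indices)
```

An index library exists to avoid this linear scan on large corpora, but the returned pair has the same meaning: one list of scores and one list of corpus positions.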
Rerank using pretrained weights
from retrievals import AutoModelForRanking
model_name_or_path: str = "BAAI/bge-reranker-base"
rerank_model = AutoModelForRanking.from_pretrained(model_name_or_path)
scores_list = rerank_model.compute_score([
    "In 1974, I won the championship in Southeast Asia in my first kickboxing match",
    "In 1982, I defeated the heavy hitter Ryu Long.",
])
print(scores_list)
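A reranker's scores are typically used to reorder candidate documents and keep the best few. The sketch below shows that step in plain Python; the `rerank` helper and the scores are hypothetical and do not come from the model above.

```python
from typing import List, Tuple


def rerank(documents: List[str], scores: List[float],
           top_n: int = 3) -> List[Tuple[str, float]]:
    """Sort documents by descending relevance score and keep the top_n."""
    if len(documents) != len(scores):
        raise ValueError("documents and scores must have the same length")
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]


toy_docs = ["doc a", "doc b", "doc c"]
toy_scores = [0.12, 0.87, 0.45]  # hypothetical relevance scores
top_docs = rerank(toy_docs, toy_scores, top_n=2)
print(top_docs)
```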
RAG with LangChain integration
pip install langchain
pip install langchain_community
pip install chromadb
from retrievals.tools.langchain import LangchainEmbedding, LangchainReranker, LangchainLLM
from retrievals import AutoModelForRanking
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores import Chroma as Vectorstore
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import RetrievalQA
persist_directory = './database/faiss.index'
embed_model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
rerank_model_name_or_path = "BAAI/bge-reranker-base"
llm_model_name_or_path = "microsoft/Phi-3-mini-128k-instruct"
embeddings = LangchainEmbedding(model_name=embed_model_name_or_path)
vectordb = Vectorstore(
    persist_directory=persist_directory,
    embedding_function=embeddings,
)
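# score_threshold only takes effect with the "similarity_score_threshold" search type,
# and LangChain reads it, together with k, from search_kwargs.
retrieval_args = {"search_type": "similarity_score_threshold", "search_kwargs": {"score_threshold": 0.15, "k": 10}}
retriever = vectordb.as_retriever(**retrieval_args)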
ranker = AutoModelForRanking.from_pretrained(rerank_model_name_or_path)
reranker = LangchainReranker(model=ranker, top_n=3)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker, base_retriever=retriever
)
llm = LangchainLLM(model_name_or_path=llm_model_name_or_path)
RESPONSE_TEMPLATE = """[INST]
<<SYS>>
You are a helpful AI assistant. Use the following pieces of context to answer the user's question.<</SYS>>
Anything between the following `context` html blocks is retrieved from a knowledge base.
<context>
{context}
</context>
REMEMBER:
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
- Let's take a deep breath and think step-by-step.
Question: {question}[/INST]
Helpful Answer:
"""
PROMPT = PromptTemplate(template=RESPONSE_TEMPLATE, input_variables=["context", "question"])
qa_chain = RetrievalQA.from_chain_type(
    llm,
    chain_type='stuff',
    retriever=compression_retriever,
    chain_type_kwargs={
        "verbose": True,
        "prompt": PROMPT,
    }
)
user_query = 'Introduce this'
response = qa_chain({"query": user_query})
print(response)
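With `chain_type='stuff'`, the retrieved (and reranked) documents are concatenated and placed into the prompt's `{context}` slot before the LLM is called. The sketch below shows that assembly step in plain Python without LangChain; the shortened template, the `build_stuff_prompt` helper, and the sample passages are hypothetical.

```python
from typing import List

# Shortened, hypothetical template with the same two placeholders as above.
TEMPLATE = (
    "Use the following context to answer the question.\n"
    "<context>\n{context}\n</context>\n"
    "Question: {question}\n"
    "Helpful Answer:"
)


def build_stuff_prompt(documents: List[str], question: str,
                       separator: str = "\n\n") -> str:
    """Join all documents into one context string and fill the template."""
    context = separator.join(documents)
    return TEMPLATE.format(context=context, question=question)


toy_prompt = build_stuff_prompt(
    ["First retrieved passage.", "Second retrieved passage."],
    "What do the passages say?",
)
print(toy_prompt)
```

Because every document lands in a single prompt, the number of documents kept by the reranker (`top_n=3` above) directly bounds the prompt length.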
Embedding fine-tuning
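import torch.nn as nn
from datasets import load_dataset
from torch.optim import AdamW  # transformers' own AdamW is deprecated in favor of the PyTorch implementation
from transformers import AutoTokenizer, get_linear_schedule_with_warmup, TrainingArguments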
from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
batch_size: int = 128
epochs: int = 3
train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'document'})
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
# model = model.set_train_type('pointwise') # 'pointwise', 'pairwise', 'listwise'
optimizer = AdamW(model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
training_arguments = TrainingArguments(
    output_dir='./checkpoints',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    remove_unused_columns=False,
)
trainer = RetrievalTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=PairCollator(tokenizer, query_max_length=128, document_max_length=128),
    loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()
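The training run above uses an InfoNCE loss over query-document pairs. A common form of this loss treats the other documents in the batch as negatives: each query's similarity row is passed through a softmax cross-entropy whose target is its own paired document. The sketch below computes that in plain Python. It is an illustration under stated assumptions, not the internals of `retrievals.losses.InfoNCE`; the temperature value is assumed, label smoothing is omitted, and the toy vectors are hypothetical.

```python
import math
from typing import List


def info_nce(query_embs: List[List[float]], doc_embs: List[List[float]],
             temperature: float = 0.05) -> float:
    """In-batch InfoNCE: query_embs[i] and doc_embs[i] form the positive pair.

    Vectors are assumed to be L2-normalized already.
    """
    losses = []
    for i, q in enumerate(query_embs):
        # Similarity of this query to every document in the batch.
        logits = [sum(a * b for a, b in zip(q, d)) / temperature for d in doc_embs]
        # Numerically stable log-sum-exp.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        # Cross-entropy with the paired document as the target class.
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


# Toy batch (hypothetical): correctly aligned pairs versus swapped pairs.
aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
shuffled = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned, shuffled)
```

The loss is near zero when each query is closest to its own document and grows when another document in the batch scores higher, which is the signal that pulls paired embeddings together.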
Reranking fine-tuning
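from torch.optim import AdamW  # transformers' own AdamW is deprecated in favor of the PyTorch implementation
from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup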
from retrievals import RerankCollator, AutoModelForRanking, RerankTrainer, RerankTrainDataset
model_name_or_path: str = "microsoft/deberta-v3-base"
max_length: int = 128
learning_rate: float = 3e-5
batch_size: int = 4
epochs: int = 3
train_dataset = RerankTrainDataset('./t2rank.json', positive_key='pos', negative_key='neg')
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForRanking.from_pretrained(model_name_or_path)
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
training_args = TrainingArguments(
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    num_train_epochs=epochs,
    output_dir='./checkpoints',
    remove_unused_columns=False,
)
trainer = RerankTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=RerankCollator(tokenizer, max_length=max_length),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()
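The scheduler in this example raises the learning rate linearly during warmup and then decays it along half a cosine wave. The sketch below reproduces that multiplier in plain Python, following the documented default behavior of `get_cosine_schedule_with_warmup` (half a cycle). The function name `cosine_warmup_multiplier` and the dataset size are hypothetical; the batch size, epoch count, warmup ratio, and base learning rate are taken from the code above.

```python
import math


def cosine_warmup_multiplier(step: int, num_warmup_steps: int,
                             num_training_steps: int) -> float:
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup.
        return step / max(1, num_warmup_steps)
    # Fraction of the post-warmup phase that has elapsed.
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Half-cosine decay from 1 down to 0.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Hypothetical dataset size; other values mirror the example above.
num_examples, batch_size, epochs = 1000, 4, 3
num_train_steps = int(num_examples / batch_size * epochs)
num_warmup_steps = int(0.05 * num_train_steps)
base_lr = 3e-5

for step in (0, num_warmup_steps, num_train_steps // 2, num_train_steps):
    lr = base_lr * cosine_warmup_multiplier(step, num_warmup_steps, num_train_steps)
    print(step, lr)
```

The learning rate peaks at the base value when warmup ends and reaches zero on the final step.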