
Text Embeddings for Retrieval and RAG based on transformers


Documentation | Tutorials | Chinese (中文)

Open-Retrievals is an easy-to-use Python framework for obtaining state-of-the-art (SOTA) text embeddings, oriented toward information retrieval and LLM retrieval-augmented generation (RAG), built on PyTorch and Transformers.

  • Contrastive learning enhanced embeddings
  • LLM embeddings

Installation

Prerequisites

pip install transformers
pip install faiss-cpu
pip install peft

With pip

pip install open-retrievals

Quick-start

from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

# Example list of documents
documents = [
    "Open-retrievals is a text embedding libraries",
    "I can use it simply with a SOTA RAG application.",
]

# This will trigger the model download and initialization
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)

embeddings = model.encode(documents)
len(embeddings)  # 2 documents, each embedded as a 384-dimensional vector
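With the embeddings in hand, semantic similarity is a simple vector operation. A minimal sketch, assuming encode returns a NumPy array of shape (2, 384) as in sentence-transformers:

import numpy as np

a, b = embeddings[0], embeddings[1]
cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
print(cosine)  # closer to 1.0 means more semantically similar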

Usage

Build Index and Retrieval

from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

sentences = ['A dog is chasing a car.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
model.build_index(sentences)

matcher = AutoModelForRetrieval()
results = matcher.faiss_search("He plays guitar.")
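
For intuition, here is the same flow expressed directly with FAISS (an illustrative sketch, not Open-Retrievals' internals; it assumes encode returns arrays convertible to float32 NumPy):

import faiss
import numpy as np

passage_vecs = np.asarray(model.encode(sentences), dtype="float32")  # FAISS expects float32
index = faiss.IndexFlatIP(passage_vecs.shape[1])  # exact inner-product index
index.add(passage_vecs)

query_vec = np.asarray(model.encode(["He plays guitar."]), dtype="float32")
scores, ids = index.search(query_vec, k=1)  # top-1 passage per query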

Rerank

from transformers import AutoTokenizer
from retrievals import RerankCollator, RerankModel, RerankTrainer, RerankDataset

# data_args, model_args and training_args are assumed to be parsed elsewhere,
# e.g. with transformers.HfArgumentParser; get_optimizer and get_scheduler are
# the training helpers used throughout these examples.
train_dataset = RerankDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = RerankModel(
    model_args.model_name_or_path,
    pooling_method="mean"
)
optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)

# num_train_steps: presumably len(dataset) / batch_size (2) * num_epochs (1)
lr_scheduler = get_scheduler(optimizer, num_train_steps=int(len(train_dataset) / 2 * 1))

trainer = RerankTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=RerankCollator(tokenizer, max_length=data_args.query_max_len),
)
trainer.optimizer = optimizer
trainer.scheduler = lr_scheduler
trainer.train()
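
At inference time, a reranker scores (query, passage) pairs so retrieved candidates can be reordered. For intuition, a generic cross-encoder scoring sketch using plain transformers (not this library's inference API), with BAAI/bge-reranker-base as an example checkpoint:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
reranker = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-base")
reranker.eval()

pairs = [
    ["what is a cat?", "A cat is a small domesticated feline."],
    ["what is a cat?", "Paris is the capital of France."],
]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
    scores = reranker(**inputs).logits.view(-1)  # higher score = more relevant pair
print(scores)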

RAG with LangChain

  • Prerequisites

pip install langchain

  • Server

from retrievals.tools.langchain import LangchainEmbedding, LangchainReranker
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores import Chroma as Vectorstore


class DenseRetrieval:
    def __init__(self, persist_directory):
        embeddings = LangchainEmbedding(model_name="BAAI/bge-large-zh-v1.5")
        vectordb = Vectorstore(
            persist_directory=persist_directory,
            embedding_function=embeddings,
        )
        # as_retriever expects keyword arguments with search_kwargs nested,
        # not a single flat dict
        self.retriever = vectordb.as_retriever(
            search_type="similarity",
            search_kwargs={"score_threshold": 0.15, "k": 30},
        )

        reranker_args = {
            "model": "../../inputs/bce-reranker-base_v1",
            "top_n": 7,
            "device": "cuda",
            "use_fp16": True,
        }
        self.reranker = LangchainReranker(**reranker_args)
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=self.retriever
        )

    def query(
        self,
        question: str
    ):
        docs = self.compression_retriever.get_relevant_documents(question)
        return docs
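
A usage sketch, assuming persist_directory points at a previously built Chroma collection and the reranker path above exists locally (the directory name here is illustrative):

retrieval = DenseRetrieval(persist_directory="./chroma_db")  # illustrative path
docs = retrieval.query("How do I fine-tune an embedding model?")
for doc in docs:
    print(doc.page_content)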

Use a pretrained sentence embedding model

from retrievals import AutoModelForEmbedding

sentences = ["Hello world", "How are you?"]
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path, pooling_method="mean", normalize_embeddings=True)
sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
print(sentence_embeddings)
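
Because normalize_embeddings=True produces unit-length vectors, cosine similarity reduces to a matrix product (a small sketch over the tensor returned above):

similarity = sentence_embeddings @ sentence_embeddings.T  # (2, 2) cosine-similarity matrix
print(similarity)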

Fine-tune transformers with contrastive learning

from transformers import AutoTokenizer
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval, RetrievalTrainer, PairCollator, TripletCollator
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
from retrievals.data import RetrievalDataset, RerankDataset


# as in the rerank example, the *_args objects and the optimizer/scheduler
# helpers are assumed to be defined elsewhere
train_dataset = RetrievalDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = AutoModelForEmbedding(
    model_args.model_name_or_path,
    pooling_method="cls"
)
optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)

lr_scheduler = get_scheduler(optimizer, num_train_steps=int(len(train_dataset) / 2 * 1))

trainer = RetrievalTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=TripletCollator(tokenizer, max_length=data_args.query_max_len),
    loss_fn=TripletLoss(),
)
trainer.optimizer = optimizer
trainer.scheduler = lr_scheduler
trainer.train()
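
The collator and loss travel together: TripletCollator feeds (query, positive, negative) batches to TripletLoss, while pairwise data goes through PairCollator with an in-batch-negatives loss such as InfoNCE. A hedged sketch of the pair variant, reusing the names imported above:

trainer = RetrievalTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=PairCollator(tokenizer, max_length=data_args.query_max_len),
    loss_fn=InfoNCE(),  # contrastive loss with in-batch negatives
)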

Fine-tune an LLM for embedding with contrastive learning

from retrievals import AutoModelForEmbedding

model = AutoModelForEmbedding(
    "mistralai/Mistral-7B-v0.1",
    pooling_method='cls',
    query_instruction='Instruct: Retrieve semantically similar text\nQuery: '
)
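
Encoding then works as in the earlier examples; a sketch, assuming query_instruction is prepended to queries at encode time (note that a 7B model needs substantial GPU memory, and PEFT/LoRA from the prerequisites is the usual route for parameter-efficient fine-tuning):

query_embeddings = model.encode(["Which planet is known as the Red Planet?"])
passage_embeddings = model.encode(["Mars is often called the Red Planet."])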

Search by cosine similarity/KNN

from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

query_texts = ['A dog is chasing a car.']
passage_texts = ['A man is playing a guitar.', 'A bee is flying low']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
query_embeddings = model.encode(query_texts, convert_to_tensor=True)
passage_embeddings = model.encode(passage_texts, convert_to_tensor=True)

matcher = AutoModelForRetrieval(method='cosine')
dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1)
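
similarity_search returns distances and the indices of the best-matching passages; assuming indices has shape (num_queries, top_k), they map back to the original texts:

for query, idx_row in zip(query_texts, indices):
    for idx in idx_row:
        print(query, "->", passage_texts[idx])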

Reference & Acknowledge
