Neural-Tree

Neural Search

Are tree-based indexes the counterpart of standard ANN algorithms for IR models based on token-level embeddings? Neural-Tree replicates the SIGIR 2023 publication Constructing Tree-based Index for Efficient and Effective Dense Retrieval in order to accelerate ColBERT. Neural-Tree is also compatible with Sentence Transformers and TfIdf models, as in the original paper.

Neural-Tree creates a tree using hierarchical clustering of documents and then learns embeddings in each node of the tree using paired queries and documents. You can also provide an existing tree structure in JSON format to build the index, as sketched below.
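
For illustration only, a pre-built tree might be described as nested nodes whose leaves list document ids. The exact JSON schema Neural-Tree expects is specified in its documentation; the field names below are assumptions, not the library's confirmed format.

# Hypothetical shape of an existing tree provided in JSON form: internal
# nodes map to child nodes, and leaves map to lists of document ids.
# Check the Neural-Tree documentation for the actual expected schema.
existing_tree = {
    "node_0": {"leaf_0": [0, 1, 4], "leaf_1": [2, 3]},
    "node_1": {"leaf_2": [5, 6, 7], "leaf_3": [8, 9]},
}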

Neural-Tree optimizes the index to preserve the performance of the original model while significantly speeding up the search process. It does not modify the underlying model, so it is advisable to start tree creation from a model that has already been fine-tuned. Because the model is left untouched, training the index is relatively quick.

Installation

We can install neural-tree using:

pip install neural-tree

If we plan to evaluate our model during training, install:

pip install "neural-tree[eval]"

Documentation

The complete documentation is available here.

Quick Start

The following code shows how to train a tree index. Let's start by creating a fictional dataset:

documents = [
    {"id": 0, "content": "paris"},
    {"id": 1, "content": "london"},
    {"id": 2, "content": "berlin"},
    {"id": 3, "content": "rome"},
    {"id": 4, "content": "bordeaux"},
    {"id": 5, "content": "milan"},
]

train_queries = [
    "paris is the capital of france",
    "london is the capital of england",
    "berlin is the capital of germany",
    "rome is the capital of italy",
]

train_documents = [
    {"id": 0, "content": "paris"},
    {"id": 1, "content": "london"},
    {"id": 2, "content": "berlin"},
    {"id": 3, "content": "rome"},
]

test_queries = [
    "bordeaux is the capital of france",
    "milan is the capital of italy",
]

Let's train the index using the documents, train_queries and train_documents we have gathered.

import torch
from neural_cherche import models
from neural_tree import clustering, trees, utils

model = models.ColBERT(
    model_name_or_path="raphaelsty/neural-cherche-colbert",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

tree = trees.ColBERT(
    key="id",
    on=["content"],
    model=model,
    documents=documents,
    leaf_balance_factor=100,  # Number of documents per leaf
    branch_balance_factor=5,  # Number children per node
    n_jobs=-1,  # set to 1 with Google Colab
)

optimizer = torch.optim.AdamW(lr=3e-3, params=list(tree.parameters()))

for step, batch_queries, batch_documents in utils.iter(
    queries=train_queries,
    documents=train_documents,
    shuffle=True,
    epochs=50,
    batch_size=32,
):
    loss = tree.loss(
        queries=batch_queries,
        documents=batch_documents,
    )

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

Let's now duplicate some documents across the leaves of the tree in order to increase accuracy.

documents_to_leafs = clustering.optimize_leafs(
    tree=tree,
    queries=train_queries + test_queries,
    documents=documents,
)

tree = tree.add(
    documents=documents,
    documents_to_leafs=documents_to_leafs,
)

We are now ready to retrieve documents:

scores = tree(
    queries=["bordeaux", "milan"],
    k_leafs=2,
    k=2,
)

print(scores["documents"])
[
    [
        {"id": 4, "similarity": 5.28, "leaf": "12"},
        {"id": 0, "similarity": 3.17, "leaf": "12"},
    ],
    [
        {"id": 5, "similarity": 5.11, "leaf": "10"},
        {"id": 2, "similarity": 3.57, "leaf": "10"},
    ],
]

Evaluation

We can evaluate the performance of the tree using the following code:

from neural_tree import datasets, utils

documents, queries_ids, test_queries, qrels = datasets.load_beir_test(
    dataset_name="scifact",
)

candidates = tree(
    queries=test_queries,
    k_leafs=2,
    k=10,
)

scores = utils.evaluate(
    scores=candidates["documents"],
    qrels=qrels,
    queries_ids=queries_ids,
)

print(scores)

Benchmarks

Scifact Dataset

| Model | HuggingFace Checkpoint | Vanilla ndcg@10 | Vanilla hits@10 | Vanilla hits@1 | Vanilla queries/second | Neural-Tree ndcg@10 | Neural-Tree hits@10 | Neural-Tree hits@1 | Neural-Tree queries/second | Acceleration |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| TfIdf (Cherche) | - | 0.61 | 0.85 | 0.47 | 760 | 0.56 | 0.82 | 0.42 | 1080 | +42.11% |
| SentenceTransformer GPU (Faiss.IndexFlatL2 CPU) | sentence-transformers/all-mpnet-base-v2 | 0.66 | 0.89 | 0.53 | 475 | 0.66 | 0.88 | 0.53 | 518 | +9.05% |
| ColBERT (Neural-Cherche GPU) | raphaelsty/neural-cherche-colbert | 0.70 | 0.92 | 0.58 | 3 | 0.70 | 0.91 | 0.59 | 256 | x85 |

Note that this benchmark does not implement ColBERTv2 efficient retrieval but rather compares raw ColBERT retrieval with Neural-Tree. The vanilla SentenceTransformer baseline could also be accelerated by using an optimized Faiss index, as sketched below.
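
As a rough illustration of that last point (not part of Neural-Tree and not used in the benchmark above), an IVF index from Faiss could replace IndexFlatL2 for the SentenceTransformer baseline. The parameters below are assumptions for a toy corpus and would need tuning on a real dataset such as Scifact.

# Sketch only: swapping Faiss IndexFlatL2 for an IVF index to speed up the
# vanilla SentenceTransformer baseline. nlist and nprobe are illustrative.
import faiss
from sentence_transformers import SentenceTransformer

documents = ["paris", "london", "berlin", "rome", "bordeaux", "milan"]

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
embeddings = model.encode(documents, normalize_embeddings=True)

d = embeddings.shape[1]
quantizer = faiss.IndexFlatL2(d)             # coarse quantizer
index = faiss.IndexIVFFlat(quantizer, d, 2)  # nlist=2 clusters for this toy corpus
index.train(embeddings)
index.add(embeddings)
index.nprobe = 1                             # clusters visited at query time

query_embeddings = model.encode(["bordeaux"], normalize_embeddings=True)
distances, ids = index.search(query_embeddings, 2)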

Contributing

We welcome contributions to Neural-Tree, such as improving tree visualization, modeling node topics, and leveraging the tree structure to speed up Large Language Model (LLM) searches. Our focus includes refining the hierarchical clustering of ColBERT embeddings, which is currently driven by TfIdf. There is also an opportunity to help optimize the clustering so that ColBERT clusters can be built entirely without relying on TfIdf.

License

This project is licensed under the terms of the MIT license.
