DraftRetriever

DraftRetriever is a core component of ADED, a retrieval-based speculative decoding method that accelerates large language model (LLM) decoding without fine-tuning by using an adaptive draft-verification process. ADED represents token probabilities with a tri-gram matrix that it adjusts dynamically, and it uses Monte Carlo Tree Search (MCTS) to balance exploration and exploitation when generating drafts, so that accurate drafts are produced quickly. This speeds up decoding while maintaining high accuracy, which makes the method suited to practical applications.
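To make the tri-gram idea concrete, here is a minimal sketch of a tri-gram count table used to propose draft tokens. It is a toy illustration only: the `TrigramTable` class and its methods are hypothetical names, it does not use MCTS, and it is not how DraftRetriever or ADED is implemented internally.

```python
from collections import Counter, defaultdict


class TrigramTable:
    """Toy tri-gram statistics (hypothetical helper, not the library's API)."""

    def __init__(self):
        # Maps a two-token context (t1, t2) to a Counter of following tokens.
        self.counts = defaultdict(Counter)

    def add_sequence(self, tokens):
        # Slide a window of three tokens over the sequence and count each tri-gram.
        for a, b, c in zip(tokens, tokens[1:], tokens[2:]):
            self.counts[(a, b)][c] += 1

    def next_token_probs(self, a, b):
        # Relative frequency of each token observed after the context (a, b).
        followers = self.counts.get((a, b))
        if not followers:
            return {}
        total = sum(followers.values())
        return {tok: n / total for tok, n in followers.items()}

    def draft(self, context, max_len):
        # Greedily extend the last two context tokens with the most frequent follower.
        a, b = context[-2], context[-1]
        out = []
        for _ in range(max_len):
            followers = self.counts.get((a, b))
            if not followers:
                break
            nxt = followers.most_common(1)[0][0]
            out.append(nxt)
            a, b = b, nxt
        return out


table = TrigramTable()
table.add_sequence([1, 2, 3, 4, 5])
table.add_sequence([1, 2, 3, 4, 6])
table.add_sequence([1, 2, 3, 4, 5])
print(table.next_token_probs(3, 4))
print(table.draft([1, 2], max_len=3))  # [3, 4, 5]
```

A greedy walk like this always exploits the most frequent continuation. A search procedure such as MCTS would also explore less frequent branches, which this sketch leaves out.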

Installation

Prerequisites:

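The published wheels target CPython 3.11 and 3.12 on x86-64 Linux (manylinux, glibc 2.17+). If they are not compatible with your system, install Rust and maturin so that the package can be built from source: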

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
pip install maturin

Example

Generate Tri-gram Matrix

import json
import os

import draftretriever
from tqdm import tqdm
from transformers import AutoTokenizer

# Placeholder: set this to the local path or Hugging Face ID of your model.
model_path = "path/to/your/model"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# The writer builds the datastore index file at this path.
datastore_path = './datastore_chat_large.idx'
writer = draftretriever.Writer(
    file_path=datastore_path,
    vocab_size=tokenizer.vocab_size,
)

dataset_path = "datastore/ShareGPT_V4.3_unfiltered_cleaned_split.json"
assert os.path.exists(dataset_path), "please download the dataset from https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered"
with open(dataset_path) as f:
    dataset = json.load(f)
total_length = len(dataset)
print("number of samples: ", total_length)

# Tokenize every message of every conversation and add it to the datastore.
for conversations in tqdm(dataset, total=total_length):
    for sample in conversations['conversations']:
        token_list = tokenizer.encode(sample['value'])
        writer.add_entry(token_list)

# Flush the index to disk; the datastore is ready to be read after this call.
writer.finalize()
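The loop above assumes every record has a 'conversations' list whose items carry a 'value' string. As a hedged sketch, the helper below iterates over ShareGPT-style records defensively, skipping records or messages that lack those fields. The name `iter_message_texts` and the 'from' keys in the sample data are assumptions made for illustration; they are not part of draftretriever.

```python
def iter_message_texts(dataset):
    """Yield the text of every non-empty message in a ShareGPT-style dataset."""
    for record in dataset:
        # Records without a 'conversations' list contribute nothing.
        for message in record.get("conversations", []):
            text = message.get("value")
            # Skip missing, non-string, or empty message bodies.
            if isinstance(text, str) and text:
                yield text


# Tiny in-memory sample in the assumed ShareGPT layout.
sample = [
    {"conversations": [{"from": "human", "value": "Hi"},
                       {"from": "gpt", "value": "Hello!"}]},
    {"id": "no-conversations"},
    {"conversations": [{"from": "human", "value": ""}]},
]
print(list(iter_message_texts(sample)))  # ['Hi', 'Hello!']
```

Each yielded string could then be passed to tokenizer.encode and writer.add_entry exactly as in the loop above.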

Search

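import draftretriever

# datastore_path is the index file written in the previous step.
datastore = draftretriever.Reader(index_file_path=datastore_path)

# token_list holds the token IDs of the current context and max_num_draft sets
# the `choices` limit; both must be defined by the caller.
(retrieved_token_list, _draft_attn_mask, _tree_indices,
 _draft_position_ids, _retrieve_indices) = datastore.search(token_list, choices=max_num_draft)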
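Retrieved drafts still have to be verified against the target model before they are accepted. The sketch below shows the general idea in its simplest form: keep the longest prefix of any candidate that matches what the model would have produced. The function `longest_accepted_prefix` and the `verify_next` callback are hypothetical stand-ins, and the sequential check is an assumption made for readability; it is not ADED's verification code.

```python
def longest_accepted_prefix(candidates, verify_next):
    """Return the longest prefix of any candidate draft that the verifier accepts.

    candidates:  list of draft token lists.
    verify_next: callable that takes the tokens accepted so far and returns the
                 token the target model would emit next (hypothetical stand-in
                 for a real model call).
    """
    best = []
    for cand in candidates:
        accepted = []
        for tok in cand:
            # Stop at the first draft token the model disagrees with.
            if verify_next(accepted) != tok:
                break
            accepted.append(tok)
        if len(accepted) > len(best):
            best = accepted
    return best


# Toy "model" that always continues one fixed sequence.
target = [7, 8, 9, 10]


def toy_verify_next(accepted):
    return target[len(accepted)] if len(accepted) < len(target) else None


drafts = [[7, 8, 3], [7, 8, 9, 4], [5, 6]]
print(longest_accepted_prefix(drafts, toy_verify_next))  # [7, 8, 9]
```

Tree-based speculative decoding implementations typically score all candidates in a single forward pass using an attention mask over the draft tree, rather than calling the model once per token as this sketch does.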

License

Distributed under the MIT License. See LICENSE for more information.

Acknowledgement

The main framework is based on REST (Retrieval-Based Speculative Decoding).

Download files

Download the file for your platform.

Source Distribution

draftretriever-0.1.1.tar.gz (9.8 kB)

Built Distributions

draftretriever-0.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (247.7 kB)
  CPython 3.12, manylinux (glibc 2.17+), x86-64

draftretriever-0.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (256.1 kB)
  CPython 3.11, manylinux (glibc 2.17+), x86-64

File details

Details for the file draftretriever-0.1.1.tar.gz.

File metadata

  • Download URL: draftretriever-0.1.1.tar.gz
  • Size: 9.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: maturin/1.7.4

File hashes

Hashes for draftretriever-0.1.1.tar.gz
Algorithm Hash digest
SHA256 698f3d0af6afdaf22621dd47944f9198d72c244b32ef2cdc67833663a3ee2c12
MD5 9133fcaa011a5206fbb8ccc828f4e636
BLAKE2b-256 831111c254f1d4aae5378c3e27e66a45a13e31425a9b9e922bb5df30247b2a46

File details

Details for the file draftretriever-0.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for draftretriever-0.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 803275d84c549400965bb8ed53437c8f5be8082bdc86577a83aef31ef64731f8
MD5 8e135133d261a29cd983afb4ede52dde
BLAKE2b-256 41f016857400d418020438d05f7b897304be8eda8a143f83a6162fe1bb4a1af4

File details

Details for the file draftretriever-0.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for draftretriever-0.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 85491c9ad09a936b840171db292c02f1248d0748b4830d3d6d436e0c4499d4be
MD5 3e7b4d61f523e10ff45e2a9c15fe4050
BLAKE2b-256 9f9c01fb92a32204e89046482bda58bcd8901bb663572c9fe1d9e1a54fc32fc6
