Tokengrams

Efficiently computing & storing token n-grams from large corpora
Tokengrams allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora used to train large language models. It does this not by explicitly pre-computing the $n$-gram counts for fixed $n$, but by creating a suffix array index which allows you to efficiently compute the count of an $n$-gram on the fly for any $n$.
Our code also allows you to turn your suffix array index into an efficient $n$-gram language model, which can be used to generate text or compute the perplexity of a given text.
The backend is written in Rust, and the Python bindings are generated using PyO3.
Installation
pip install tokengrams
Usage
Preparing data
Use a dataset of u16 or u32 tokens, or prepare one from a HuggingFace dataset.
# Get pre-tokenized dataset
from huggingface_hub import HfApi, hf_hub_download
hf_hub_download(
    repo_id="EleutherAI/pile-standard-pythia-preshuffled",
    repo_type="dataset",
    filename="document-00000-of-00020.bin",
    local_dir="."
)
# Tokenize HF dataset
from tokengrams import tokenize_hf_dataset
from datasets import load_dataset
from transformers import AutoTokenizer
tokenize_hf_dataset(
    dataset=load_dataset("EleutherAI/lambada_openai", "en"),
    tokenizer=AutoTokenizer.from_pretrained("EleutherAI/pythia-160m"),
    output_path="lambada.bin",
    text_key="text",
    append_eod=True,
    workers=1,
)
Building an index
from tokengrams import MemmapIndex
# Create a new index from an on-disk corpus of u16 tokens and save it to a .idx file.
# Set verbose to true to include a progress bar for the index sort.
index = MemmapIndex.build(
    "document-00000-of-00020.bin",
    "document-00000-of-00020.idx",
    vocab=2**16,
    verbose=True
)
# True for any valid index.
print(index.is_sorted())
# Get the count of "hello world" in the corpus.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
print(index.count(tokenizer.encode("hello world")))
# You can now load the index from disk later using __init__
index = MemmapIndex(
    "document-00000-of-00020.bin",
    "document-00000-of-00020.idx",
    vocab=2**16
)
Using an index
# Count how often each token in the corpus succeeds "hello world".
print(index.count_next(tokenizer.encode("hello world")))
print(index.batch_count_next(
    [tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
))
# Get smoothed probabilities for query continuations
print(index.smoothed_probs(tokenizer.encode("hello world")))
print(index.batch_smoothed_probs(
    [tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
))
# Autoregressively sample 10 tokens per sample using 5-gram statistics. Initial
# statistics are derived from the query, falling back to lower-order n-grams
# until the sequence contains at least 5 tokens.
print(index.sample_unsmoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))
print(index.sample_smoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))
# Query whether the corpus contains "hello world"
print(index.contains(tokenizer.encode("hello world")))
# Get all n-grams beginning with "hello world" in the corpus
print(index.positions(tokenizer.encode("hello world")))
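Unsmoothed sampling amounts to drawing each next token in proportion to how often it continues the current $(n-1)$-token context, backing off to shorter contexts when no continuation exists. A toy pure-Python sketch of that idea (a naive scan rather than the library's suffix-array lookup):

```python
import random

corpus = [1, 2, 3, 1, 2, 1, 2, 3]

def count_next(context):
    """Count which tokens follow `context` in the corpus."""
    counts = {}
    for i in range(len(corpus) - len(context)):
        if corpus[i:i + len(context)] == context:
            nxt = corpus[i + len(context)]
            counts[nxt] = counts.get(nxt, 0) + 1
    return counts

def sample_unsmoothed(query, n, k, seed=0):
    """Append k tokens, each drawn from the empirical n-gram distribution."""
    rng = random.Random(seed)
    seq = list(query)
    for _ in range(k):
        context = seq[-(n - 1):]
        counts = count_next(context)
        while not counts and context:  # back off to shorter contexts if needed
            context = context[1:]
            counts = count_next(context)
        tokens, weights = zip(*counts.items())
        seq.append(rng.choices(tokens, weights=weights)[0])
    return seq

print(sample_unsmoothed([1, 2], n=3, k=4))
```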
Scaling
Corpora small enough to fit in memory can use an InMemoryIndex:
from tokengrams import InMemoryIndex
tokens = [0, 1, 2, 3, 4]
index = InMemoryIndex(tokens, vocab=5)
Larger corpora must use a MemmapIndex.
Many systems struggle to memory-map extremely large tables (e.g. 40 billion tokens), causing unexpected bus errors. To prevent this, split the corpus into shards, then use a ShardedMemmapIndex to sort and query the table shard by shard:
from tokengrams import ShardedMemmapIndex
from huggingface_hub import HfApi, hf_hub_download
files = [
    file for file in HfApi().list_repo_files(
        "EleutherAI/pile-standard-pythia-preshuffled", repo_type="dataset"
    )
    if file.endswith(".bin")
]
index_paths = []
for file in files:
    hf_hub_download(
        "EleutherAI/pile-standard-pythia-preshuffled",
        repo_type="dataset",
        filename=file,
        local_dir="."
    )
    index_paths.append((file, f"{file.removesuffix('.bin')}.idx"))
index = ShardedMemmapIndex.build(index_paths, vocab=2**16, verbose=True)
Tokens
Tokengrams builds indices from on-disk corpora of either u16 or u32 tokens, supporting a maximum vocabulary size of $2^{32}$. In practice, however, vocabulary size is limited by the largest vector of the chosen word size that the machine can allocate in memory.

Corpora with vocabulary sizes smaller than $2^{16}$ must use u16 tokens.
Performance
Index build times for in-memory corpora scale inversely with the number of available CPU threads; if the index reads from or writes to a file, the build is likely to be IO bound.

The time complexity of count_next(query) and sample_unsmoothed(query) is O(n log n), where n is approximately the number of completions of the query. The time complexity of sample_smoothed(query) is O(m n log n), where m is the n-gram order.
Development
cargo build
cargo test
Develop Python bindings:
pip install maturin
maturin develop
pytest
Support
The best way to get support is to open an issue on this repo or post in #interp-across-time in the EleutherAI Discord server. If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!