Utilities for embedding experiments with cross-platform array support
embedding_tools
embedding_tools provides a backend-agnostic interface for working with embeddings across NumPy, MLX (Apple Silicon), JAX, and PyTorch. It includes memory management, configuration versioning, and similarity search utilities for machine learning research.
Features
- 🔄 Backend Abstraction: Seamlessly switch between NumPy, MLX, JAX, and PyTorch
- 💾 Memory Management: Track and limit memory usage with EmbeddingStore
- 🔍 Similarity Search: Built-in cosine similarity and nearest neighbor search
- 📦 Dimension Slicing: Efficient truncation for Matryoshka embeddings
- 🔐 Configuration Versioning: SHA-256 hashing for reproducible experiments
- 🍎 Apple Silicon Optimized: Native MLX support for M-series Macs
Installation
# Core (NumPy only)
pip install embedding_tools
# With MLX (Apple Silicon)
pip install embedding_tools[mlx]
# With PyTorch
pip install embedding_tools[torch]
# With JAX (GPU/TPU support)
pip install embedding_tools[jax]
# Everything
pip install embedding_tools[all]
Development Installation
To contribute or use the latest development version:
git clone https://github.com/nborwankar/embedding_tools.git
cd embedding_tools
pip install -e ".[dev]"
Quick Start
Basic Array Operations
from embedding_tools import get_backend
# Auto-detect best available backend
backend = get_backend() # Uses MLX > JAX > PyTorch > NumPy
# Or specify explicitly
backend = get_backend('numpy') # CPU
backend = get_backend('mlx') # Apple Silicon GPU (fastest on Mac)
backend = get_backend('jax') # GPU/TPU with JIT compilation
backend = get_backend('torch') # PyTorch (CUDA/MPS/CPU)
# Create arrays
embeddings = backend.create_array([[1, 2, 3], [4, 5, 6]])
# Compute similarities
query = backend.create_array([1, 2, 3])
sims = backend.cosine_similarity(query, embeddings)
# Slice to lower dimensions (for Matryoshka embeddings)
truncated = backend.slice_last_dim(embeddings, dim=2)
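As a rough illustration of what the two operations above compute, here is a plain NumPy sketch. It is not the library's implementation: the function bodies, the `eps` guard, and the choice to return a 2-D similarity matrix are assumptions made for this example.

```python
import numpy as np

def cosine_similarity(a, b, eps=1e-12):
    """Cosine similarity between every row of a and every row of b."""
    a = np.atleast_2d(np.asarray(a, dtype=np.float32))
    b = np.atleast_2d(np.asarray(b, dtype=np.float32))
    # Scale each row to unit length; eps guards against all-zero rows
    a_unit = a / np.maximum(np.linalg.norm(a, axis=-1, keepdims=True), eps)
    b_unit = b / np.maximum(np.linalg.norm(b, axis=-1, keepdims=True), eps)
    return a_unit @ b_unit.T

def slice_last_dim(array, dim):
    """Keep only the first `dim` entries along the last axis."""
    return np.asarray(array)[..., :dim]

embeddings = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
query = np.array([1, 2, 3], dtype=np.float32)

sims = cosine_similarity(query, embeddings)    # shape (1, 2)
truncated = slice_last_dim(embeddings, dim=2)  # shape (2, 2)
```

The first row of `embeddings` equals the query, so its similarity is 1 up to floating-point rounding.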
Memory-Safe Embedding Storage
from embedding_tools import EmbeddingStore
import numpy as np
# Create store with memory limit
store = EmbeddingStore(backend='mlx', max_memory_gb=10.0)
# Add embeddings
embeddings_1024d = np.random.randn(10000, 1024).astype(np.float32)
store.add_embeddings(embeddings_1024d, dimension=1024)
# Slice to lower dimensions (Matryoshka)
embeddings_128d = store.slice_to_dimension(source_dim=1024, target_dim=128)
# Similarity search
query = np.random.randn(1024).astype(np.float32)
similarities, indices = store.compute_similarity(
    query,
    dimension=1024,
    top_k=10
)
# Check memory usage
info = store.get_memory_info()
print(f"Total memory: {info['total_gb']:.2f} GB")
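To make the similarity search step concrete, the following NumPy sketch shows one common way to get the top-k most similar rows. The function name `top_k_cosine` and the use of `argpartition` are assumptions for illustration; the source does not describe how `compute_similarity` works internally.

```python
import numpy as np

def top_k_cosine(query, embeddings, top_k):
    """Return (similarities, indices) of the top_k rows most similar to query."""
    q = query / np.linalg.norm(query)
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = e @ q                                  # one cosine score per row
    top_k = min(top_k, sims.shape[0])
    # argpartition finds the top_k candidates without sorting every score
    idx = np.argpartition(-sims, top_k - 1)[:top_k]
    idx = idx[np.argsort(-sims[idx])]             # order candidates best-first
    return sims[idx], idx

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((1000, 64)).astype(np.float32)
query = embeddings[42].copy()                     # a stored vector as the query

similarities, indices = top_k_cosine(query, embeddings, top_k=10)
```

Because the query is a copy of row 42, that row comes back as the best match.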
Configuration Versioning
from embedding_tools import compute_config_hash, compute_param_hash
# Hash a configuration dict
config = {
    'model': 'sentence-transformers/all-MiniLM-L6-v2',
    'dimension': 384,
    'batch_size': 32
}
hash_val = compute_config_hash(config) # Returns 16-char hex string
# Or use keyword arguments
hash_val = compute_param_hash(
    model='all-MiniLM-L6-v2',
    dimension=384,
    batch_size=32
)
# Use for automatic cache invalidation
cache_key = f"embeddings_{hash_val}.npz"
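For intuition, a configuration hash of this kind can be sketched with the standard library alone. The canonical JSON serialization and the truncation to 16 characters below are assumptions; the actual serialization used by `compute_config_hash` is not documented here, so hashes from this sketch should not be expected to match the library's.

```python
import hashlib
import json

def config_hash_sketch(config, length=16):
    """Hypothetical helper: stable short hash of a configuration dict."""
    # sort_keys makes the hash independent of key insertion order;
    # default=str is a fallback for values JSON cannot encode directly
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"), default=str)
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    return digest[:length]  # a full SHA-256 hex digest is 64 characters

a = config_hash_sketch({'model': 'all-MiniLM-L6-v2', 'dimension': 384})
b = config_hash_sketch({'dimension': 384, 'model': 'all-MiniLM-L6-v2'})
c = config_hash_sketch({'model': 'all-MiniLM-L6-v2', 'dimension': 256})
```

Here `a` and `b` are identical because only key order differs, while `c` differs because a value changed. That property is what makes the hash usable as a cache key.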
Backend Comparison
| Backend | Hardware | Speed (relative to NumPy) | Memory | Installation |
|---|---|---|---|---|
| NumPy | CPU | 1x | System RAM | pip install embedding_tools |
| MLX | Apple Silicon GPU | 3-5x | Unified memory | pip install embedding_tools[mlx] |
| JAX | GPU/TPU (Metal/CUDA/ROCm) | 5-10x* | GPU VRAM | pip install embedding_tools[jax] |
| PyTorch | CUDA/MPS/CPU | 2-4x | GPU VRAM | pip install embedding_tools[torch] |
*Speed with JIT compilation on repeated operations
Device Options for PyTorch:
- device='cuda': NVIDIA GPUs (Linux/Windows)
- device='mps': Apple Silicon GPU (macOS)
- device='cpu': CPU fallback (all platforms)
Device Options for JAX:
- device='gpu': GPU acceleration (Metal/CUDA/ROCm)
- device='cpu': CPU fallback
- device=None: Auto-detection (recommended)
# Explicit device configuration
from embedding_tools import get_backend, EmbeddingStore
# PyTorch: CUDA for NVIDIA GPUs (Linux production)
backend = get_backend('torch', device='cuda')
store = EmbeddingStore(backend='torch', max_memory_gb=40.0, device='cuda')
# PyTorch: MPS for Apple Silicon
backend = get_backend('torch', device='mps')
store = EmbeddingStore(backend='torch', max_memory_gb=20.0, device='mps')
# JAX: GPU acceleration (auto-detects Metal/CUDA/ROCm)
backend = get_backend('jax', device='gpu')
store = EmbeddingStore(backend='jax', max_memory_gb=20.0, device='gpu')
# Auto-detection (recommended)
backend = get_backend('torch') # Automatically picks best device
backend = get_backend('jax') # Automatically picks best device
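If you want to choose the PyTorch device yourself before passing it in, a check along these lines is one option. The helper name `pick_torch_device` and the CUDA > MPS > CPU order are assumptions for this sketch; the library's own auto-detection logic is not documented here.

```python
def pick_torch_device():
    """Hypothetical helper: return 'cuda', 'mps', or 'cpu' based on availability."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch is not installed, so only CPU work is possible
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is absent in older PyTorch releases
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

device = pick_torch_device()
print(f"Selected device: {device}")
```

The result can then be passed as the `device` argument shown in the examples above.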
Installation Validation
Run validation tests after installation:
pytest tests/test_installation.py -v
Or run directly:
python tests/test_installation.py
Expected output:
============================================================
embedding_tools Installation Validation Summary
============================================================
Version: 0.1.0
NumPy backend: ✓ Available
MLX backend: ✓ Available
Auto-detected backend: MLXBackend
All core functionality tests passed!
============================================================
Development
# Clone repository
git clone https://github.com/nborwankar/embedding_tools.git
cd embedding_tools
# Install in development mode
pip install -e ".[dev]"
# Run tests
pytest tests/ -v
# Format code
black .
isort .
# Lint
flake8 embedding_tools/
API Reference
Array Backends
get_backend(backend_name=None, device=None)
Get array backend instance.
Parameters:
- backend_name (str, optional): 'numpy', 'mlx', 'jax', or 'torch'. Auto-detects if None.
- device (str, optional): Device specification for JAX/PyTorch backends. Auto-detects if None.
Returns: ArrayBackend instance
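The auto-detection described above can be pictured as walking a preference list and returning the first backend whose package is importable. This is a sketch under that assumption: the helper `detect_backend_name` is hypothetical, and the order simply mirrors the MLX > JAX > PyTorch > NumPy order stated in the Quick Start.

```python
import importlib.util

# (backend name, top-level module to probe), in order of preference
PREFERENCE = [("mlx", "mlx"), ("jax", "jax"), ("torch", "torch"), ("numpy", "numpy")]

def detect_backend_name(preference=PREFERENCE):
    """Hypothetical helper: name of the first backend whose module is installed."""
    for name, module in preference:
        # find_spec locates the package without importing it
        if importlib.util.find_spec(module) is not None:
            return name
    raise ImportError("No supported array backend is installed")

# Demonstration with a stdlib module so the example runs anywhere
chosen = detect_backend_name([("missing", "no_such_module_xyz"), ("stdlib", "json")])
```

Probing with `find_spec` avoids paying the import cost of every installed framework just to decide which one to use.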
ArrayBackend Methods
- create_array(data, dtype=None) - Create array from data
- zeros(shape, dtype=None) - Create zero-filled array
- ones(shape, dtype=None) - Create one-filled array
- random_normal(shape, mean=0.0, std=1.0) - Random normal array
- dot(a, b) - Dot product
- cosine_similarity(a, b) - Cosine similarity matrix
- normalize(a, axis=-1) - L2 normalization
- concatenate(arrays, axis=0) - Concatenate arrays
- stack(arrays, axis=0) - Stack arrays
- slice_last_dim(array, dim) - Slice to dimension
- to_numpy(array) - Convert to NumPy
- from_numpy(array) - Convert from NumPy
- save(array, filepath) - Save to file
- load(filepath) - Load from file
- get_memory_usage(array) - Memory in bytes
- get_shape(array) - Array shape
- get_dtype(array) - Array dtype
Memory Management
EmbeddingStore(backend='numpy', max_memory_gb=8.0)
In-memory embedding storage with memory limits.
Methods:
- add_embeddings(embeddings, dimension, text_ids=None, labels=None, metadata=None)
- get_embeddings(dimension) - Retrieve embeddings
- slice_to_dimension(source_dim, target_dim) - Matryoshka slicing
- compute_similarity(query_emb, dimension, top_k=None) - Similarity search
- get_available_dimensions() - List stored dimensions
- get_total_memory_usage() - Total memory in bytes
- get_memory_info() - Detailed memory statistics
- save_to_disk(directory) - Save all embeddings
- load_from_disk(directory) - Load all embeddings
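To illustrate the idea of a memory-limited store keyed by dimension, here is a toy NumPy version. The class `TinyStore`, its use of `MemoryError`, and the byte accounting in GiB are all assumptions for this sketch; they are not taken from the `EmbeddingStore` implementation.

```python
import numpy as np

class TinyStore:
    """Toy store: one array per dimension, with a total byte budget."""

    def __init__(self, max_memory_gb=8.0):
        self.max_bytes = int(max_memory_gb * 1024**3)
        self._by_dim = {}

    def total_bytes(self):
        return sum(arr.nbytes for arr in self._by_dim.values())

    def add_embeddings(self, embeddings, dimension):
        arr = np.asarray(embeddings)
        if arr.shape[-1] != dimension:
            raise ValueError(f"expected last axis {dimension}, got {arr.shape[-1]}")
        # Do not double-count an array that is about to be replaced
        existing = self._by_dim.get(dimension)
        current = self.total_bytes() - (existing.nbytes if existing is not None else 0)
        if current + arr.nbytes > self.max_bytes:
            raise MemoryError("adding these embeddings would exceed the limit")
        self._by_dim[dimension] = arr

    def slice_to_dimension(self, source_dim, target_dim):
        # Copy the leading columns so the slice owns its memory
        sliced = np.ascontiguousarray(self._by_dim[source_dim][:, :target_dim])
        self.add_embeddings(sliced, target_dim)
        return sliced

store = TinyStore(max_memory_gb=1e-5)                      # about 10 KiB budget
store.add_embeddings(np.zeros((10, 64), np.float32), 64)   # 2,560 bytes
small = store.slice_to_dimension(64, 32)                   # 1,280 bytes more
```

The same arithmetic scales up: 10,000 float32 vectors of 1024 dimensions occupy 10,000 × 1024 × 4 = 40,960,000 bytes, roughly 0.04 GB.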
Configuration
compute_config_hash(config)
Compute SHA-256 hash of configuration dictionary.
Parameters:
config(dict): Configuration dictionary
Returns: 16-character hex string
compute_param_hash(**kwargs)
Convenience function for hashing keyword arguments.
Returns: 16-character hex string
Use Cases
Matryoshka Embeddings
from embedding_tools import EmbeddingStore
store = EmbeddingStore(backend='mlx', max_memory_gb=20)
# model, documents, query, ground_truth, and compute_recall are
# placeholders for your own code; they are not imported from embedding_tools
# Encode documents with a model that produces 1024D embeddings
full_embeddings = model.encode(documents) # (N, 1024)
store.add_embeddings(full_embeddings, dimension=1024)
# Get truncated versions for different use cases
embeddings_512 = store.slice_to_dimension(1024, 512) # Moderate accuracy
embeddings_128 = store.slice_to_dimension(1024, 128) # Fast search
embeddings_32 = store.slice_to_dimension(1024, 32) # Ultra-fast
# Compare at different dimensions
for dim in [32, 128, 512, 1024]:
    sims, indices = store.compute_similarity(query, dim, top_k=10)
    print(f"{dim}D recall@10: {compute_recall(indices, ground_truth)}")
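The example above relies on a `compute_recall` helper that the README does not define. One possible definition, together with a self-contained NumPy version of the comparison loop, is sketched below. Everything here is an assumption for illustration: the helper names, the renormalization after truncation, and the random data, which is not Matryoshka-trained and so only demonstrates the mechanics.

```python
import numpy as np

def truncate_and_normalize(embeddings, dim):
    """Keep the first `dim` coordinates, then rescale to unit length."""
    sliced = np.asarray(embeddings, dtype=np.float32)[..., :dim]
    norms = np.linalg.norm(sliced, axis=-1, keepdims=True)
    return sliced / np.maximum(norms, 1e-12)

def top_k_indices(query, embeddings, k):
    """Indices of the k rows with the highest dot product against query."""
    return np.argsort(-(embeddings @ query))[:k]

def compute_recall(retrieved, ground_truth):
    """Fraction of ground-truth ids that appear among the retrieved ids."""
    truth = {int(i) for i in ground_truth}
    if not truth:
        return 0.0
    return sum(1 for i in retrieved if int(i) in truth) / len(truth)

rng = np.random.default_rng(0)
docs = rng.standard_normal((500, 64)).astype(np.float32)
query = rng.standard_normal(64).astype(np.float32)

# Treat the full-dimension top-10 as ground truth
ground_truth = top_k_indices(truncate_and_normalize(query, 64),
                             truncate_and_normalize(docs, 64), k=10)
for dim in [8, 16, 32, 64]:
    retrieved = top_k_indices(truncate_and_normalize(query, dim),
                              truncate_and_normalize(docs, dim), k=10)
    print(f"{dim}D recall@10: {compute_recall(retrieved, ground_truth):.2f}")
```

With real Matryoshka embeddings, recall at low dimensions would be expected to hold up far better than it does on random vectors.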
Cross-Platform Development
import platform
from embedding_tools import get_backend
# Development on Mac (uses MLX for speed)
if platform.system() == 'Darwin':
    backend = get_backend('mlx')
# Production on Linux (NumPy here; use get_backend('torch', device='cuda') for NVIDIA GPUs)
else:
    backend = get_backend('numpy')
# Same code works everywhere (data and query are your own inputs)
embeddings = backend.create_array(data)
similarities = backend.cosine_similarity(query, embeddings)
Experiment Versioning
from embedding_tools import compute_param_hash
import os
# load_results, run_experiment, and save_results are placeholders for your own code
# Compute hash of experiment parameters
exp_hash = compute_param_hash(
    model='all-MiniLM-L6-v2',
    chunk_size=512,
    overlap=50,
    dimension=384
)
# Check if results exist
results_file = f'results_{exp_hash}.json'
if os.path.exists(results_file):
    print("Loading cached results...")
    results = load_results(results_file)
else:
    print("Running new experiment...")
    results = run_experiment()
    save_results(results, results_file)
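The `load_results` and `save_results` helpers in the example above are left to the reader. A minimal JSON-based sketch, assuming the results are JSON-serializable, could look like this; `run_cached` is a hypothetical wrapper that packages the same check-then-run pattern.

```python
import json
from pathlib import Path

def save_results(results, path):
    """Write results as pretty-printed JSON."""
    Path(path).write_text(json.dumps(results, indent=2, sort_keys=True), encoding="utf-8")

def load_results(path):
    """Read results previously written by save_results."""
    return json.loads(Path(path).read_text(encoding="utf-8"))

def run_cached(path, run_experiment):
    """Return cached results if the file exists, otherwise run and cache."""
    path = Path(path)
    if path.exists():
        return load_results(path)
    results = run_experiment()
    save_results(results, path)
    return results
```

Calling `run_cached(f'results_{exp_hash}.json', run_experiment)` twice would run the experiment only the first time, as long as the parameters and therefore the hash are unchanged.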
License
MIT License - see LICENSE file for details.
Contributing
Contributions welcome! Please read CONTRIBUTING.md for guidelines.
Citation
If you use embedding_tools in your research, please cite:
@software{embedding_tools2025,
title = {embedding_tools: Utilities for embedding experiments},
author = {Nitin Borwankar},
year = {2025},
url = {https://github.com/nborwankar/embedding_tools}
}