Utilities for embedding experiments with cross-platform array support

embedding_tools

Requires Python 3.8+. License: MIT.

embedding_tools provides a backend-agnostic interface for working with embeddings across NumPy, MLX (Apple Silicon), JAX, and PyTorch. It includes memory management, configuration versioning, and similarity search utilities optimized for machine learning research.

Features

  • 🔄 Backend Abstraction: Seamlessly switch between NumPy, MLX, JAX, and PyTorch
  • 💾 Memory Management: Track and limit memory usage with EmbeddingStore
  • 🔍 Similarity Search: Built-in cosine similarity and top-k nearest-neighbor search
  • 📦 Dimension Slicing: Efficient truncation for Matryoshka embeddings
  • 🔐 Configuration Versioning: SHA-256 hashing for reproducible experiments
  • 🍎 Apple Silicon Optimized: Native MLX support for M-series Macs

Installation

# Core (NumPy only)
pip install embedding_tools

# Quoting the extras keeps shells such as zsh (the macOS default) from expanding the brackets.

# With MLX (Apple Silicon)
pip install "embedding_tools[mlx]"

# With PyTorch
pip install "embedding_tools[torch]"

# With JAX (GPU/TPU support)
pip install "embedding_tools[jax]"

# Everything
pip install "embedding_tools[all]"

Development Installation

To contribute or use the latest development version:

git clone https://github.com/nborwankar/embedding_tools.git
cd embedding_tools
pip install -e ".[dev]"

Quick Start

Basic Array Operations

from embedding_tools import get_backend

# Auto-detect best available backend
backend = get_backend()  # Uses MLX > JAX > PyTorch > NumPy

# Or specify explicitly (backends other than NumPy need their extra installed)
backend = get_backend('numpy')  # CPU
backend = get_backend('mlx')    # Apple Silicon GPU (fastest on Mac)
backend = get_backend('jax')    # GPU/TPU with JIT compilation
backend = get_backend('torch')  # PyTorch (CUDA/MPS/CPU)

# Create arrays
embeddings = backend.create_array([[1, 2, 3], [4, 5, 6]])

# Compute similarities
query = backend.create_array([1, 2, 3])
sims = backend.cosine_similarity(query, embeddings)

# Slice to lower dimensions (for Matryoshka embeddings)
truncated = backend.slice_last_dim(embeddings, dim=2)
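For reference, cosine similarity between a query vector and each row of an embedding matrix is the dot product of the L2-normalized vectors. The NumPy sketch below shows that computation for the arrays used above; it is an illustration, not the library's implementation, and the name `cosine_similarity_np` is hypothetical.

```python
import numpy as np


def cosine_similarity_np(query: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
    """Cosine similarity between a 1-D query and each row of a 2-D matrix."""
    query = np.asarray(query, dtype=np.float32)
    embeddings = np.asarray(embeddings, dtype=np.float32)
    # L2-normalize both sides; the tiny epsilon guards against zero vectors
    q = query / (np.linalg.norm(query) + 1e-12)
    e = embeddings / (np.linalg.norm(embeddings, axis=-1, keepdims=True) + 1e-12)
    # One dot product per row: result has shape (n_rows,)
    return e @ q


embeddings = np.array([[1, 2, 3], [4, 5, 6]])
query = np.array([1, 2, 3])
print(cosine_similarity_np(query, embeddings))  # first entry is ~1.0: same direction as the query
```

Normalizing first means the score depends only on direction, not vector length, which is why a row identical to the query scores 1.0.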

Memory-Safe Embedding Storage

from embedding_tools import EmbeddingStore
import numpy as np

# Create store with memory limit
store = EmbeddingStore(backend='mlx', max_memory_gb=10.0)

# Add embeddings
embeddings_1024d = np.random.randn(10000, 1024).astype(np.float32)
store.add_embeddings(embeddings_1024d, dimension=1024)

# Slice to lower dimensions (Matryoshka)
embeddings_128d = store.slice_to_dimension(source_dim=1024, target_dim=128)

# Similarity search
query = np.random.randn(1024).astype(np.float32)
similarities, indices = store.compute_similarity(
    query,
    dimension=1024,
    top_k=10
)

# Check memory usage
info = store.get_memory_info()
print(f"Total memory: {info['total_gb']:.2f} GB")
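The memory figures come down to simple arithmetic: a dense float32 matrix takes rows × dims × 4 bytes. The sketch below estimates the footprint of the example above, assuming the 128-D slice is kept as a separate copy and ignoring per-array overhead. Whether `get_memory_info()` reports decimal GB or binary GiB is not stated here, so both are printed; `estimate_bytes` is a hypothetical helper.

```python
import numpy as np


def estimate_bytes(n_rows: int, dim: int, dtype=np.float32) -> int:
    """Bytes needed for a dense (n_rows, dim) array of the given dtype."""
    return n_rows * dim * np.dtype(dtype).itemsize


full = estimate_bytes(10_000, 1024)   # 40,960,000 bytes
sliced = estimate_bytes(10_000, 128)  # 5,120,000 bytes (if stored as a copy)
total = full + sliced

# Compare against the store's limit (max_memory_gb=10.0 in the example)
print(f"{total:,} bytes = {total / 1e9:.3f} GB = {total / 2**30:.3f} GiB")
```

At roughly 0.05 GB, the example sits far below the 10 GB limit; the limit only starts to matter at millions of rows.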

Configuration Versioning

from embedding_tools import compute_config_hash, compute_param_hash

# Hash a configuration dict
config = {
    'model': 'sentence-transformers/all-MiniLM-L6-v2',
    'dimension': 384,
    'batch_size': 32
}
hash_val = compute_config_hash(config)  # Returns 16-char hex string

# Or use keyword arguments
hash_val = compute_param_hash(
    model='all-MiniLM-L6-v2',
    dimension=384,
    batch_size=32
)

# Use for automatic cache invalidation
cache_key = f"embeddings_{hash_val}.npz"

Backend Comparison

Backend  Hardware                   Speed vs. NumPy  Memory          Installation
NumPy    CPU                        1x (baseline)    System RAM      pip install embedding_tools
MLX      Apple Silicon GPU          3-5x             Unified memory  pip install "embedding_tools[mlx]"
JAX      GPU/TPU (Metal/CUDA/ROCm)  5-10x*           GPU VRAM        pip install "embedding_tools[jax]"
PyTorch  CUDA/MPS/CPU               2-4x             GPU VRAM        pip install "embedding_tools[torch]"

*With JIT compilation, on repeated operations; the first call also pays the compilation cost.
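The speed multipliers above are relative to NumPy and depend on hardware, array sizes, and whether compilation has been warmed up. One way to get your own NumPy baseline is to time a query-vs-matrix cosine similarity with the standard library's `timeit`. The sizes below are arbitrary choices and `time_cosine_numpy` is a hypothetical helper; timing the same workload through another backend gives a comparable number.

```python
import timeit

import numpy as np


def time_cosine_numpy(n_rows: int = 10_000, dim: int = 1024, repeats: int = 5) -> float:
    """Best-of-N wall time in seconds for one cosine-similarity pass in NumPy."""
    rng = np.random.default_rng(0)
    embeddings = rng.standard_normal((n_rows, dim), dtype=np.float32)
    query = rng.standard_normal(dim, dtype=np.float32)

    def run() -> np.ndarray:
        e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        q = query / np.linalg.norm(query)
        return e @ q

    run()  # warm-up call, analogous to letting a JIT compile before timing
    # Take the minimum: it is the least affected by background noise
    return min(timeit.repeat(run, number=1, repeat=repeats))


print(f"NumPy baseline: {time_cosine_numpy() * 1e3:.2f} ms")
```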

Device Options for PyTorch:

  • device='cuda': NVIDIA GPUs (Linux/Windows)
  • device='mps': Apple Silicon GPU (macOS)
  • device='cpu': CPU fallback (all platforms)

Device Options for JAX:

  • device='gpu': GPU acceleration (Metal/CUDA/ROCm)
  • device='cpu': CPU fallback
  • device=None: Auto-detection (recommended)

# Explicit device configuration
from embedding_tools import get_backend, EmbeddingStore

# PyTorch: CUDA for NVIDIA GPUs (Linux production)
backend = get_backend('torch', device='cuda')
store = EmbeddingStore(backend='torch', max_memory_gb=40.0, device='cuda')

# PyTorch: MPS for Apple Silicon
backend = get_backend('torch', device='mps')
store = EmbeddingStore(backend='torch', max_memory_gb=20.0, device='mps')

# JAX: GPU acceleration (auto-detects Metal/CUDA/ROCm)
backend = get_backend('jax', device='gpu')
store = EmbeddingStore(backend='jax', max_memory_gb=20.0, device='gpu')

# Auto-detection (recommended)
backend = get_backend('torch')  # Automatically picks best device
backend = get_backend('jax')    # Automatically picks best device
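The auto-detection comments above imply a preference among CUDA, MPS, and CPU for PyTorch. A minimal sketch of that kind of check, using PyTorch's public availability functions, is shown below. The CUDA-then-MPS-then-CPU order is an assumption about what embedding_tools does, and `pick_torch_device` is a hypothetical helper, not part of the package. It falls back to 'cpu' when PyTorch is not installed.

```python
def pick_torch_device() -> str:
    """Return 'cuda', 'mps', or 'cpu' depending on what this machine supports."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch not installed: CPU is the only option
    if torch.cuda.is_available():
        return "cuda"  # NVIDIA GPU with a working CUDA runtime
    mps = getattr(torch.backends, "mps", None)  # absent in older PyTorch releases
    if mps is not None and mps.is_available():
        return "mps"  # Apple Silicon GPU via Metal
    return "cpu"


device = pick_torch_device()
print(f"Selected device: {device}")
# The result can then be passed explicitly, e.g. get_backend('torch', device=device)
```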

Installation Validation

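Run the validation tests after installation. The paths are relative to the repository root, so run these commands from a clone of the repository: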

pytest tests/test_installation.py -v

Or run directly:

python tests/test_installation.py

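Expected output (example from an Apple Silicon Mac with MLX installed; the version line reflects your installed release):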

============================================================
embedding_tools Installation Validation Summary
============================================================
Version: 0.1.0
NumPy backend: ✓ Available
MLX backend: ✓ Available
Auto-detected backend: MLXBackend

All core functionality tests passed!
============================================================

Development

# Clone repository
git clone https://github.com/nborwankar/embedding_tools.git
cd embedding_tools

# Install in development mode
pip install -e ".[dev]"

# Run tests
pytest tests/ -v

# Format code
black .
isort .

# Lint
flake8 embedding_tools/

API Reference

Array Backends

get_backend(backend_name=None, device=None)

Get array backend instance.

Parameters:

  • backend_name (str, optional): 'numpy', 'mlx', 'jax', or 'torch'. Auto-detects if None.
  • device (str, optional): Device specification for JAX/PyTorch backends. Auto-detects if None.

Returns: ArrayBackend instance
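The auto-detection order given in the Quick Start (MLX > JAX > PyTorch > NumPy) can be approximated by checking which packages are importable. This is a hypothetical sketch using `importlib.util.find_spec`, not the package's actual detection code, which may also check for hardware (for example, whether a GPU is present).

```python
import importlib.util
from typing import List, Optional

# Preference order from the Quick Start comment; backend names match module names here
PREFERENCE: List[str] = ["mlx", "jax", "torch", "numpy"]


def detect_backend_name(preference: List[str] = PREFERENCE) -> Optional[str]:
    """Return the first backend whose package can be imported, or None."""
    for name in preference:
        # find_spec locates the package without importing it (cheap, no side effects)
        if importlib.util.find_spec(name) is not None:
            return name
    return None


print(detect_backend_name())
```

Checking with `find_spec` avoids paying the import cost of heavy frameworks that end up unused.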

ArrayBackend Methods

  • create_array(data, dtype=None) - Create array from data
  • zeros(shape, dtype=None) - Create zero-filled array
  • ones(shape, dtype=None) - Create one-filled array
  • random_normal(shape, mean=0.0, std=1.0) - Random normal array
  • dot(a, b) - Dot product
  • cosine_similarity(a, b) - Cosine similarity matrix
  • normalize(a, axis=-1) - L2 normalization
  • concatenate(arrays, axis=0) - Concatenate arrays
  • stack(arrays, axis=0) - Stack arrays
  • slice_last_dim(array, dim) - Truncate the last axis to its first dim components (Matryoshka slicing)
  • to_numpy(array) - Convert to NumPy
  • from_numpy(array) - Convert from NumPy
  • save(array, filepath) - Save to file
  • load(filepath) - Load from file
  • get_memory_usage(array) - Memory in bytes
  • get_shape(array) - Array shape
  • get_dtype(array) - Array dtype
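To illustrate how a backend-agnostic interface like this is typically structured, here is a minimal sketch: a `typing.Protocol` listing a few of the methods above, plus a NumPy implementation. The names `MiniBackend`, `MiniNumpyBackend`, and `describe` are hypothetical, and the real `ArrayBackend` likely differs in detail (for example, default dtypes or edge-case handling).

```python
from typing import Any, Protocol, Tuple

import numpy as np


class MiniBackend(Protocol):
    """A subset of the ArrayBackend method list, for illustration only."""

    def create_array(self, data: Any, dtype: Any = None) -> Any: ...
    def normalize(self, a: Any, axis: int = -1) -> Any: ...
    def slice_last_dim(self, array: Any, dim: int) -> Any: ...
    def get_shape(self, array: Any) -> Tuple[int, ...]: ...
    def get_memory_usage(self, array: Any) -> int: ...


class MiniNumpyBackend:
    """NumPy implementation that satisfies MiniBackend structurally."""

    def create_array(self, data, dtype=None):
        # Assumption for this sketch: default to float32, common for embeddings
        return np.asarray(data, dtype=dtype if dtype is not None else np.float32)

    def normalize(self, a, axis=-1):
        norms = np.linalg.norm(a, axis=axis, keepdims=True)
        return a / np.maximum(norms, 1e-12)  # avoid dividing by zero

    def slice_last_dim(self, array, dim):
        # Matryoshka-style truncation: keep the first `dim` components
        return array[..., :dim]

    def get_shape(self, array):
        return tuple(array.shape)

    def get_memory_usage(self, array):
        return int(array.nbytes)


def describe(backend: MiniBackend, data) -> str:
    """Works with any object that provides the MiniBackend methods."""
    arr = backend.slice_last_dim(backend.create_array(data), dim=2)
    return f"shape={backend.get_shape(arr)}, bytes={backend.get_memory_usage(arr)}"


print(describe(MiniNumpyBackend(), [[1, 2, 3], [4, 5, 6]]))  # → shape=(2, 2), bytes=16
```

Because `Protocol` uses structural typing, an MLX or PyTorch class with the same method names would be accepted by `describe` without inheriting from anything.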

Memory Management

EmbeddingStore(backend='numpy', max_memory_gb=8.0)

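In-memory embedding storage with a configurable memory limit (max_memory_gb). For the torch and jax backends, a device argument (for example device='cuda') can also be passed, as in the device examples above.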

Methods:

  • add_embeddings(embeddings, dimension, text_ids=None, labels=None, metadata=None)
  • get_embeddings(dimension) - Retrieve embeddings
  • slice_to_dimension(source_dim, target_dim) - Matryoshka slicing
  • compute_similarity(query_emb, dimension, top_k=None) - Similarity search; returns (similarities, indices)
  • get_available_dimensions() - List stored dimensions
  • get_total_memory_usage() - Total memory in bytes
  • get_memory_info() - Detailed memory statistics
  • save_to_disk(directory) - Save all embeddings
  • load_from_disk(directory) - Load all embeddings
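The Quick Start shows `compute_similarity` returning `(similarities, indices)` for the `top_k` best matches. A common way to implement that step in NumPy is `np.argpartition` followed by sorting only the k candidates, which avoids fully sorting all N scores. The sketch below assumes cosine scoring and descending order; `top_k_cosine` is a hypothetical name, and treating `top_k=None` as "rank every row" is an assumption, not documented behavior.

```python
from typing import Optional, Tuple

import numpy as np


def top_k_cosine(query: np.ndarray, embeddings: np.ndarray,
                 top_k: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Return (similarities, indices), most similar first."""
    e = embeddings / np.maximum(np.linalg.norm(embeddings, axis=1, keepdims=True), 1e-12)
    q = query / max(np.linalg.norm(query), 1e-12)
    sims = e @ q
    n = sims.shape[0]
    if top_k is None or top_k >= n:
        idx = np.argsort(-sims)  # full ranking
    else:
        # argpartition moves the top_k largest scores (unordered) to the last top_k slots
        candidates = np.argpartition(sims, n - top_k)[n - top_k:]
        idx = candidates[np.argsort(-sims[candidates])]  # order only those k
    return sims[idx], idx


rng = np.random.default_rng(0)
corpus = rng.standard_normal((1000, 64)).astype(np.float32)
# A slightly perturbed copy of row 42, so row 42 should rank first
query = corpus[42] + 0.01 * rng.standard_normal(64).astype(np.float32)
sims, idx = top_k_cosine(query, corpus, top_k=5)
print(idx[0], sims[0])
```

For k much smaller than N, partitioning is O(N) rather than the O(N log N) of a full sort.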

Configuration

compute_config_hash(config)

Compute a SHA-256 hash of a configuration dictionary, truncated to 16 hex characters.

Parameters:

  • config (dict): Configuration dictionary

Returns: 16-character hex string

compute_param_hash(**kwargs)

Convenience function for hashing keyword arguments.

Returns: 16-character hex string
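The documented behavior (a SHA-256 digest shortened to 16 hex characters) can be sketched with the standard library. Two details are assumptions not confirmed here: that the dictionary is serialized as key-sorted JSON before hashing, and that the first 16 characters of the digest are kept. Under those assumptions, key order does not change the hash, and the keyword-argument variant reduces to hashing the kwargs dict. The names `config_hash_sketch` and `param_hash_sketch` are hypothetical, and their output will not necessarily match the package's hashes.

```python
import hashlib
import json
from typing import Any, Dict


def config_hash_sketch(config: Dict[str, Any]) -> str:
    """SHA-256 of canonical (key-sorted, compact) JSON, truncated to 16 hex chars."""
    # default=str turns non-JSON values (e.g. Path objects) into strings
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]


def param_hash_sketch(**kwargs: Any) -> str:
    """Keyword-argument wrapper around config_hash_sketch."""
    return config_hash_sketch(kwargs)


a = config_hash_sketch({"model": "all-MiniLM-L6-v2", "dimension": 384, "batch_size": 32})
b = param_hash_sketch(batch_size=32, dimension=384, model="all-MiniLM-L6-v2")
print(a, a == b)  # same parameters in a different order give the same hash
```

Sixteen hex characters keep 64 bits of the digest, which is ample for naming cache files but is not intended as a security measure.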

Use Cases

Matryoshka Embeddings

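from embedding_tools import EmbeddingStore
import numpy as np

store = EmbeddingStore(backend='mlx', max_memory_gb=20)

# 1024-D embeddings from a Matryoshka-trained model, e.g. model.encode(documents).
# Random data stands in for real model output here.
full_embeddings = np.random.randn(10000, 1024).astype(np.float32)  # (N, 1024)
store.add_embeddings(full_embeddings, dimension=1024)

# Get truncated versions for different use cases
embeddings_512 = store.slice_to_dimension(1024, 512)  # Moderate accuracy
embeddings_128 = store.slice_to_dimension(1024, 128)  # Fast search
embeddings_32 = store.slice_to_dimension(1024, 32)    # Ultra-fast

def compute_recall(indices, ground_truth):
    # Illustrative helper: fraction of the reference neighbors that were retrieved
    found = set(np.array(indices).tolist())
    return len(found & set(np.array(ground_truth).tolist())) / len(ground_truth)

# Use the full 1024-D top-10 as the reference result
query = np.random.randn(1024).astype(np.float32)
_, ground_truth = store.compute_similarity(query, 1024, top_k=10)

# Compare at different dimensions (query truncated to match)
for dim in [32, 128, 512, 1024]:
    sims, indices = store.compute_similarity(query[:dim], dim, top_k=10)
    print(f"{dim}D recall@10: {compute_recall(indices, ground_truth):.2f}")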

Cross-Platform Development

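import platform

from embedding_tools import get_backend

# Development on an Apple Silicon Mac (uses MLX for speed)
if platform.system() == 'Darwin' and platform.machine() == 'arm64':
    backend = get_backend('mlx')
# Production on Linux: NumPy on CPU (get_backend('torch', device='cuda') targets NVIDIA GPUs)
else:
    backend = get_backend('numpy')

# Same code works everywhere
data = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # example embeddings
embeddings = backend.create_array(data)
query = backend.create_array([0.1, 0.2, 0.3])
similarities = backend.cosine_similarity(query, embeddings)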

Experiment Versioning

from embedding_tools import compute_param_hash
import json
import os

# Compute hash of experiment parameters
exp_hash = compute_param_hash(
    model='all-MiniLM-L6-v2',
    chunk_size=512,
    overlap=50,
    dimension=384
)

# Reuse cached results if this exact configuration has run before
results_file = f'results_{exp_hash}.json'
if os.path.exists(results_file):
    print("Loading cached results...")
    with open(results_file) as f:
        results = json.load(f)
else:
    print("Running new experiment...")
    results = run_experiment()  # placeholder for your experiment code (must return JSON-serializable data)
    with open(results_file, 'w') as f:
        json.dump(results, f)

License

MIT License - see LICENSE file for details.

Contributing

Contributions welcome! Please read CONTRIBUTING.md for guidelines.

Citation

If you use embedding_tools in your research, please cite:

@software{embedding_tools2025,
  title = {embedding_tools: Utilities for embedding experiments},
  author = {Nitin Borwankar},
  year = {2025},
  url = {https://github.com/nborwankar/embedding_tools}
}

Download files

Source Distribution

embedding_tools-0.1.2.tar.gz (27.1 kB)

Uploaded Source

Built Distribution

embedding_tools-0.1.2-py3-none-any.whl (23.0 kB)

Uploaded Python 3

File details

Details for the file embedding_tools-0.1.2.tar.gz.

File metadata

  • Download URL: embedding_tools-0.1.2.tar.gz
  • Size: 27.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.4

File hashes

Hashes for embedding_tools-0.1.2.tar.gz
Algorithm Hash digest
SHA256 58c4e0317c23c82a0ae4e9d9b849e363f5cd7d6ab52c8d9856e15b6c3c1e2100
MD5 b828010ecd9778469024cfe1eb551f8e
BLAKE2b-256 9959b3748e8d7fcc573fc54348763621b32bb845ae1ea51b597b38b70c96c2ec

File details

Details for the file embedding_tools-0.1.2-py3-none-any.whl.

File hashes

Hashes for embedding_tools-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 50b89b7bf88e3db63698453ccc1e1ea99156c31442cf6f2017dd55a285a58a66
MD5 ab01fbc624850a4e0d127672272f87f7
BLAKE2b-256 71dcfb8aeb9bdd1c0feaafab7e753d2873bcd7cff3348e450ac3d54240dda33c
