Sparse Johnson-Lindenstrauss Transform with CUDA acceleration
sjlt
Sparse Johnson-Lindenstrauss Transform with CUDA acceleration for PyTorch.
Features
- GPU-accelerated sparse random projections
- Supports float32, float64, and bfloat16 data types
- Optimized CUDA kernels for high performance
- Easy integration with PyTorch workflows
Installation
Requirements
- Python >= 3.8
- PyTorch >= 1.9.0 with CUDA support
- CUDA Toolkit (version compatible with your PyTorch installation)
- C++ compiler (GCC 7-11 recommended)
Install from PyPI
pip install sjlt
Install from Source
git clone https://github.com/TRAIS-Lab/sjlt
cd sjlt
pip install -e .
Quick Start
Our sjlt implementation has the following parameters:
- original_dim: input dimension
- proj_dim: output dimension
- c: sparsity parameter, i.e., non-zeros per column (default: 1)
- threads: CUDA threads per block (default: 1024)
- fixed_blocks: CUDA blocks to use (default: 84)
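To make the role of c concrete, here is a small NumPy reference for one common SJLT construction: each input coordinate is sent to c distinct, randomly chosen output coordinates with random signs, and every non-zero is scaled by 1/sqrt(c), so the implicit projection matrix has exactly c non-zeros per column. This is a CPU sketch for illustration only; the names make_sjlt and apply_sjlt are hypothetical, and the random scheme used inside the sjlt CUDA kernel may differ.

```python
import numpy as np

def make_sjlt(original_dim, proj_dim, c=1, seed=0):
    """Sample an implicit sparse JL matrix: c distinct rows and signs per column."""
    if not 1 <= c <= proj_dim:
        raise ValueError("c must satisfy 1 <= c <= proj_dim")
    rng = np.random.default_rng(seed)
    # rows[j] lists the c distinct output coordinates hit by input coordinate j
    rows = np.stack([rng.choice(proj_dim, size=c, replace=False)
                     for _ in range(original_dim)])
    # signs[j, k] is the +/-1 sign of the k-th non-zero in column j
    signs = rng.choice(np.array([-1.0, 1.0]), size=(original_dim, c))
    return rows, signs

def apply_sjlt(x, rows, signs, proj_dim):
    """Project x of shape (batch_size, original_dim) to (batch_size, proj_dim)."""
    x = np.asarray(x, dtype=np.float64)
    c = rows.shape[1]
    y = np.zeros((x.shape[0], proj_dim))
    for k in range(c):
        # Scatter-add: column j of x adds signs[j, k] * x[:, j] to y[:, rows[j, k]].
        # np.add.at accumulates correctly when several columns share a target row.
        np.add.at(y, (slice(None), rows[:, k]), x * signs[:, k])
    return y / np.sqrt(c)

x = np.random.default_rng(1).standard_normal((100, 1024))
rows, signs = make_sjlt(original_dim=1024, proj_dim=128, c=4)
y = apply_sjlt(x, rows, signs, proj_dim=128)
print(y.shape)  # (100, 128)
```

Because each input coordinate touches only c outputs, applying the transform costs O(c * original_dim) per row rather than O(proj_dim * original_dim) for a dense random projection, which is what makes a dedicated sparse kernel worthwhile.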
Note that the input must have a batch dimension, i.e., input.shape should be (batch_size, original_dim), and output.shape will be (batch_size, proj_dim).
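If your data has extra leading dimensions (for example, (batch, seq_len, original_dim) activations), one option is to flatten them into a single batch dimension before projecting and restore them afterwards. The helper below is a hypothetical sketch, not part of sjlt; it assumes only a callable mapping (batch_size, original_dim) to (batch_size, proj_dim), such as an SJLTProjection instance, and works with both PyTorch tensors and NumPy arrays.

```python
def project_batched(proj, x):
    """Apply a (batch_size, original_dim) -> (batch_size, proj_dim) projection
    to an input of shape (..., original_dim), returning shape (..., proj_dim).
    """
    lead_shape = x.shape[:-1]
    # Flatten all leading dimensions into one batch dimension
    y = proj(x.reshape(-1, x.shape[-1]))
    # Restore the original leading dimensions around the new last dimension
    return y.reshape(*lead_shape, y.shape[-1])
```

For example, with proj = SJLTProjection(original_dim=1024, proj_dim=128, c=4), calling project_batched(proj, torch.randn(8, 16, 1024, device='cuda')) should give a tensor of shape (8, 16, 128). Keep in mind that reshape may copy a non-contiguous input.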
The following snippet shows how to use the SJLT CUDA kernel:
import torch
from sjlt import SJLTProjection
# Create projection: 1024 -> 128 dimensions with sparsity 4
proj = SJLTProjection(original_dim=1024, proj_dim=128, c=4)
# Project some data
x = torch.randn(100, 1024, device='cuda')
y = proj(x) # Shape: [100, 128]
print(f"Compression ratio: {proj.get_compression_ratio():.2f}x")
print(f"Sparsity: {proj.get_sparsity_ratio():.1%}")
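A quick sanity check for any projection is to compare norms and pairwise distances before and after it; by the Johnson-Lindenstrauss lemma these should be approximately preserved, with the error shrinking as proj_dim grows. The sketch below uses a hypothetical helper, distortion_stats, on NumPy arrays. It assumes the projection is scaled to preserve norms in expectation; if sjlt scales its entries differently, the ratios will center on some other constant rather than 1.

```python
import numpy as np

def _summary(ratios):
    """(min, mean, max) of an array of ratios, as plain floats."""
    return float(ratios.min()), float(ratios.mean()), float(ratios.max())

def distortion_stats(x, y):
    """Summarize how a projection changed norms and pairwise distances.

    x: (batch_size, original_dim) inputs; y: (batch_size, proj_dim) outputs,
    where row i of y is the projection of row i of x.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same number of rows")
    # Per-row ratio ||y_i|| / ||x_i||
    norm_ratio = np.linalg.norm(y, axis=1) / np.linalg.norm(x, axis=1)
    # Ratio ||y_i - y_j|| / ||x_i - x_j|| for every unordered pair i < j
    i, j = np.triu_indices(x.shape[0], k=1)
    dist_ratio = (np.linalg.norm(y[i] - y[j], axis=1)
                  / np.linalg.norm(x[i] - x[j], axis=1))
    return {"norm": _summary(norm_ratio), "pairwise": _summary(dist_ratio)}
```

After the snippet above, tensors can be converted with something like distortion_stats(x.float().cpu().numpy(), y.detach().float().cpu().numpy()); the .float() call matters for bfloat16, which NumPy cannot represent.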
Troubleshooting
If installation fails:
- Ensure the CUDA toolkit is installed and nvcc is in your PATH
- Check PyTorch CUDA compatibility: python -c "import torch; print(torch.cuda.is_available())"
- Try reinstalling: pip install sjlt --no-cache-dir --force-reinstall
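The checks above can be gathered in one place with a small script. This is a sketch that relies only on shutil.which, nvcc --version, torch.version.cuda, and torch.cuda.is_available; the function name cuda_build_diagnostics is hypothetical. A mismatch between the CUDA version PyTorch was built with and the nvcc on your PATH is a common cause of extension build failures, so compare those two fields first.

```python
import shutil
import subprocess

def cuda_build_diagnostics():
    """Gather the facts the troubleshooting checklist asks about."""
    info = {
        "nvcc_path": shutil.which("nvcc"),  # None if nvcc is not on PATH
        "nvcc_release": None,
        "torch_version": None,
        "torch_cuda_version": None,
        "cuda_available": False,
    }
    if info["nvcc_path"]:
        out = subprocess.run([info["nvcc_path"], "--version"],
                             capture_output=True, text=True).stdout
        # nvcc prints a line like "Cuda compilation tools, release X.Y, VX.Y.Z"
        info["nvcc_release"] = next(
            (line.strip() for line in out.splitlines() if "release" in line), None)
    try:
        import torch
    except ImportError:
        return info  # PyTorch itself is not installed
    info["torch_version"] = str(torch.__version__)
    info["torch_cuda_version"] = torch.version.cuda  # None for CPU-only builds
    info["cuda_available"] = torch.cuda.is_available()
    return info

if __name__ == "__main__":
    for key, value in cuda_build_diagnostics().items():
        print(f"{key}: {value}")
```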
Download files
Source Distribution
sjlt-0.1.0.tar.gz (9.3 kB)
File details
Details for the file sjlt-0.1.0.tar.gz.
File metadata
- Download URL: sjlt-0.1.0.tar.gz
- Upload date:
- Size: 9.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8191705de3a6fc62b7dc5173ff7c90aa183b480cbc60f896723ee0004e57687c |
| MD5 | 9837af86401b5c13b269d9e0a062e2d9 |
| BLAKE2b-256 | 6a705acf06bfe564a3212748f72142555dbe6c9e836c4265d0f6bf1fb45f584a |
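To confirm that a downloaded sdist matches the published digests, you can hash it locally. The sketch below uses only Python's hashlib; the helper names file_digest and verify are hypothetical, and BLAKE2b-256 is computed as BLAKE2b with a 32-byte digest.

```python
import hashlib

# Digests copied from the table above for sjlt-0.1.0.tar.gz
EXPECTED = {
    "sha256": "8191705de3a6fc62b7dc5173ff7c90aa183b480cbc60f896723ee0004e57687c",
    "blake2b-256": "6a705acf06bfe564a3212748f72142555dbe6c9e836c4265d0f6bf1fb45f584a",
}

def file_digest(path, algorithm="sha256", chunk_size=1 << 20):
    """Return the hex digest of a file, reading it in chunks."""
    if algorithm == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)  # 256-bit BLAKE2b
    else:
        h = hashlib.new(algorithm)  # e.g. "sha256" or "md5"
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify(path):
    """True only if the file matches every expected digest."""
    return all(file_digest(path, alg) == digest for alg, digest in EXPECTED.items())
```

For example, verify("sjlt-0.1.0.tar.gz") after downloading the source distribution. pip can enforce the same check at install time through its hash-checking mode, using a requirements line of the form sjlt==0.1.0 --hash=sha256:<digest>.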