Sparse Johnson-Lindenstrauss Transform with CUDA acceleration
Sparse Johnson-Lindenstrauss Transform CUDA Kernel
This is a simple repository for the Sparse Johnson-Lindenstrauss Transform (SJLT) with CUDA acceleration for PyTorch.
Features
- GPU-accelerated sparse random projections
- Supports float32, float64, and bfloat16 data types
- Optimized CUDA kernels for high performance
- Easy integration with PyTorch workflows
Installation
Requirements
- Python >= 3.8
- PyTorch >= 1.9.0 with CUDA support
- CUDA Toolkit (version compatible with your PyTorch installation)
- C++ compiler (GCC 7-11 recommended)
Install from PyPI
pip install sjlt
Install from Source
git clone https://github.com/TRAIS-Lab/sjlt
cd sjlt
pip install -e .
By default, pip install -e . builds in an isolated environment, which may compile the extension against a different PyTorch build than the one you have installed. If you see an error like the following, it may help to use pip install --no-build-isolation -e . instead:
RuntimeError: The detected CUDA version (11.8) mismatches the version that was used to compile PyTorch (12.6). Please make sure to use the same CUDA versions.
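Before rebuilding, it can help to confirm which two versions the error is comparing. The sketch below is a minimal illustration, assuming the nvcc on your PATH is the toolkit the build picks up (the build may instead use the one under CUDA_HOME). It parses the release number from nvcc --version and compares it with torch.version.cuda, the CUDA version PyTorch was built with. The helper names are hypothetical, and the check flags only a differing major version, as in the 11.8 vs 12.6 error above.

```python
import re


def parse_nvcc_release(nvcc_output):
    """Extract 'X.Y' from `nvcc --version` output, e.g. '... release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return f"{match.group(1)}.{match.group(2)}" if match else None


def same_cuda_major(nvcc_version, torch_cuda_version):
    """True if the toolkit and PyTorch's CUDA build share a major version."""
    if nvcc_version is None or torch_cuda_version is None:
        # Unknown toolkit version or a CPU-only PyTorch build: treat as a mismatch.
        return False
    return nvcc_version.split(".")[0] == torch_cuda_version.split(".")[0]


def check_local_toolchain():
    """Hypothetical helper: needs nvcc on PATH and PyTorch installed."""
    import subprocess

    import torch

    out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    nvcc_version = parse_nvcc_release(out)
    # torch.version.cuda is the CUDA version PyTorch was compiled with (None for CPU-only builds).
    return nvcc_version, torch.version.cuda, same_cuda_major(nvcc_version, torch.version.cuda)
```

Calling check_local_toolchain() on the failing machine returns a tuple such as ("11.8", "12.6", False), which matches the situation described by the error message.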
Quick Start
Our sjlt implementation has the following parameters:
- original_dim: input dimension
- proj_dim: output dimension
- c: sparsity parameter, i.e., the number of non-zeros per column of the projection matrix (default: 1)
- threads: CUDA threads per block (default: 1024)
- fixed_blocks: number of CUDA blocks to use (default: 84)
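To make original_dim, proj_dim, and c concrete, here is a minimal CPU reference sketch in plain Python of a sparse JL projection as it is commonly defined: each of the original_dim input coordinates is hashed to c distinct output rows, each with a random sign and weight 1/sqrt(c). The function names and seeding are hypothetical, and the sjlt package may draw its hashes, signs, and scaling differently; this is not its CUDA kernel.

```python
import math
import random


def make_sjlt(original_dim, proj_dim, c=1, seed=0):
    """Hypothetical reference: for each input column, pick c distinct output rows and random signs."""
    if not 1 <= c <= proj_dim:
        raise ValueError("c must be between 1 and proj_dim")
    rng = random.Random(seed)
    rows = [rng.sample(range(proj_dim), c) for _ in range(original_dim)]
    signs = [[rng.choice((-1.0, 1.0)) for _ in range(c)] for _ in range(original_dim)]
    return rows, signs


def sjlt_project(x, rows, signs, proj_dim):
    """Project one vector x of length original_dim to a list of length proj_dim."""
    scale = 1.0 / math.sqrt(len(rows[0]))  # 1/sqrt(c)
    y = [0.0] * proj_dim
    for j, xj in enumerate(x):
        # Column j adds +-x[j]/sqrt(c) to each of its c hashed rows.
        for r, s in zip(rows[j], signs[j]):
            y[r] += s * scale * xj
    return y


# Example: 1024 -> 128 with c = 4.
rows, signs = make_sjlt(1024, 128, c=4)
x = [1.0 if j == 7 else 0.0 for j in range(1024)]  # one-hot input e_7
y = sjlt_project(x, rows, signs, 128)
print(sum(1 for v in y if v != 0.0))  # → 4
```

Projecting a one-hot vector exposes a single column of the matrix, so the output has exactly c non-zeros, each of magnitude 1/sqrt(c) = 0.5 here. Sparsity is what makes the transform cheap: each input coordinate touches only c outputs instead of all proj_dim.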
Note that the input is expected to have a batch dimension: input.shape should be (batch_size, original_dim), and output.shape will be (batch_size, proj_dim).
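The hypothetical helper below spells out this shape contract in plain Python so it can be checked without a GPU. For a single vector x of shape (original_dim,), one option (assuming the projection requires the batch dimension) is to add a leading batch dimension with x.unsqueeze(0) and remove it afterwards with .squeeze(0).

```python
def sjlt_output_shape(input_shape, original_dim, proj_dim):
    """Hypothetical helper: expected output shape for a (batch_size, original_dim) input."""
    if len(input_shape) != 2:
        # A single vector needs a batch dimension, e.g. proj(x.unsqueeze(0)).squeeze(0) in PyTorch.
        raise ValueError(f"expected a 2-D (batch_size, original_dim) input, got {tuple(input_shape)}")
    batch_size, dim = input_shape
    if dim != original_dim:
        raise ValueError(f"expected last dimension {original_dim}, got {dim}")
    return (batch_size, proj_dim)


print(sjlt_output_shape((100, 1024), 1024, 128))  # → (100, 128)
```

Because torch.Size is a tuple subclass, x.shape can be passed directly as input_shape.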
The following is a simple snippet of using our SJLT CUDA kernel:
import torch
from sjlt import SJLTProjection
# Create projection: 1024 -> 128 dimensions with sparsity 4
proj = SJLTProjection(original_dim=1024, proj_dim=128, c=4)
# Project some data
x = torch.randn(100, 1024, device='cuda')
y = proj(x) # Shape: [100, 128]
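Johnson-Lindenstrauss projections are meant to approximately preserve Euclidean norms. The plain-Python sketch below is not the sjlt kernel: it samples a hypothetical reference projection with c distinct rows per column and entries of +-1/sqrt(c). It checks that the ratio ||Sx||^2 / ||x||^2 averages close to 1 over many random projections of a fixed vector, which holds in expectation for this construction because the random signs make the cross terms cancel on average.

```python
import math
import random


def random_sjlt_project(x, proj_dim, c, rng):
    """Project x with a freshly sampled sparse sign matrix (c non-zeros of +-1/sqrt(c) per column)."""
    scale = 1.0 / math.sqrt(c)
    y = [0.0] * proj_dim
    for xj in x:
        for r in rng.sample(range(proj_dim), c):
            y[r] += rng.choice((-1.0, 1.0)) * scale * xj
    return y


def sq_norm(v):
    return sum(t * t for t in v)


rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(1024)]
# 1024 -> 128 with c = 4, repeated over 100 independently sampled projections.
ratios = [sq_norm(random_sjlt_project(x, 128, 4, rng)) / sq_norm(x) for _ in range(100)]
mean_ratio = sum(ratios) / len(ratios)
print(f"mean ||Sx||^2 / ||x||^2 over 100 draws: {mean_ratio:.3f}")  # close to 1.0
```

Any single projection deviates from 1 by an amount that shrinks as proj_dim grows. This is why proj_dim, rather than original_dim, governs the distortion.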
Troubleshooting
If installation fails:
- Ensure the CUDA toolkit is installed and nvcc is in PATH
- Check PyTorch CUDA compatibility: python -c "import torch; print(torch.cuda.is_available())"
- Try reinstalling: pip install sjlt --no-cache-dir --force-reinstall
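The checks above can be gathered into one report to attach to a bug report. This is a minimal sketch using only the standard library plus an optional PyTorch import. The function name and report keys are hypothetical, and it only collects information; it does not fix anything.

```python
import shutil
import subprocess


def sjlt_build_diagnostics():
    """Collect the facts the troubleshooting steps ask about; PyTorch is optional here."""
    info = {"nvcc_path": shutil.which("nvcc"), "nvcc_version_line": None,
            "torch_version": None, "torch_cuda": None, "cuda_available": None}
    if info["nvcc_path"] is not None:
        out = subprocess.run([info["nvcc_path"], "--version"],
                             capture_output=True, text=True).stdout
        # Keep the line that carries the toolkit release, e.g. "Cuda compilation tools, release 11.8, ...".
        release_lines = [ln for ln in out.splitlines() if "release" in ln]
        info["nvcc_version_line"] = release_lines[0].strip() if release_lines else None
    try:
        import torch
    except ImportError:
        return info  # PyTorch is not installed in this environment
    info["torch_version"] = str(torch.__version__)
    info["torch_cuda"] = torch.version.cuda  # None for CPU-only PyTorch builds
    info["cuda_available"] = torch.cuda.is_available()
    return info


for key, value in sjlt_build_diagnostics().items():
    print(f"{key}: {value}")
```

A nvcc_path of None points to the first item in the list, and a torch_cuda of None means a CPU-only PyTorch build was installed.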
Download files
Source Distribution
File details
Details for the file sjlt-0.1.1.tar.gz.
File metadata
- Download URL: sjlt-0.1.1.tar.gz
- Upload date:
- Size: 9.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0790f21b212ec8c146a21e6a806a48645a5f895d616946c2a4be6633385915eb |
| MD5 | 9ba7e7f23c9a9df5200a2fb2cf30a27c |
| BLAKE2b-256 | aaf842cb44c1e6a642c1a70dd56b1e9bda5abbdbb709a0720e03e6e33c0add08 |
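To check that a downloaded sjlt-0.1.1.tar.gz matches the SHA256 digest listed for it, a small standard-library sketch is enough. The function names are hypothetical; the expected digest is copied from the SHA256 entry for this file.

```python
import hashlib

EXPECTED_SHA256 = "0790f21b212ec8c146a21e6a806a48645a5f895d616946c2a4be6633385915eb"


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash the file in chunks so large archives are not read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sdist(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected.lower()


# Usage, after downloading the archive: verify_sdist("sjlt-0.1.1.tar.gz")
```

SHA256 is the digest to prefer here. MD5 is listed for compatibility but is not collision-resistant.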