
Sparse Johnson-Lindenstrauss Transform with CUDA acceleration


sjlt

Sparse Johnson-Lindenstrauss Transform with CUDA acceleration for PyTorch.
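As background, the textbook SJLT matrix has the form below, written with this package's parameter names (d = original_dim, m = proj_dim). This is the standard construction, not a description of this package's kernel. Details such as the 1/√c scaling and how the c non-zero rows per column are sampled are assumptions here.

```latex
\begin{aligned}
&S \in \mathbb{R}^{m \times d}, \qquad d = \texttt{original\_dim}, \quad m = \texttt{proj\_dim},\\
&S_{ij} = \frac{\sigma_{ij}\,\eta_{ij}}{\sqrt{c}}, \qquad
\sigma_{ij} \in \{-1,+1\} \text{ uniform}, \quad \eta_{ij} \in \{0,1\}, \quad \sum_{i=1}^{m} \eta_{ij} = c \ \ \forall j,\\
&Y = X S^{\top} \quad \text{for } X \in \mathbb{R}^{\text{batch\_size} \times d},\\
&\mathbb{E}\,\lVert S x\rVert_2^2 = \lVert x\rVert_2^2, \qquad
\Pr\!\left[\,\bigl|\lVert S x\rVert_2^2 - \lVert x\rVert_2^2\bigr| \le \varepsilon \lVert x\rVert_2^2\right] \ge 1-\delta .
\end{aligned}
```

Because every column has exactly c non-zeros, computing Sx touches only c·d entries instead of m·d. In the standard analysis the distortion bound holds with m = O(ε⁻² log(1/δ)) and c = O(ε⁻¹ log(1/δ)).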

Features

  • GPU-accelerated sparse random projections
  • Supports float32, float64, and bfloat16 data types
  • Optimized CUDA kernels for high performance
  • Easy integration with PyTorch workflows

Installation

Requirements

  • Python >= 3.8
  • PyTorch >= 1.9.0 with CUDA support
  • CUDA Toolkit (version compatible with your PyTorch installation)
  • C++ compiler (GCC 7-11 recommended)

Install from PyPI

pip install sjlt

Install from Source

git clone https://github.com/TRAIS-Lab/sjlt
cd sjlt
pip install -e .

Quick Start

Our sjlt implementation has the following parameters:

  • original_dim: input dimension
  • proj_dim: output dimension
  • c: sparsity parameter, i.e., non-zeros per column (default: 1)
  • threads: CUDA threads per block (default: 1024)
  • fixed_blocks: CUDA blocks to use (default: 84)
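To make the role of original_dim, proj_dim, and c concrete, here is a minimal pure-PyTorch sketch of the standard SJLT construction. It is a hypothetical reference (the function name sjlt_reference and its seed argument are illustrative), not the package's CUDA kernel, and it assumes the c output rows for each input coordinate are distinct with random ±1 signs scaled by 1/√c. threads and fixed_blocks are kernel launch settings and have no counterpart here.

```python
import math
import torch

def sjlt_reference(x: torch.Tensor, proj_dim: int, c: int = 1, seed: int = 0) -> torch.Tensor:
    """Project x of shape (batch_size, original_dim) to (batch_size, proj_dim).

    Hypothetical pure-PyTorch reference, not the sjlt package's kernel:
    every input coordinate j is sent to c distinct output rows with
    random signs, each scaled by 1/sqrt(c).
    """
    if x.dim() != 2:
        raise ValueError("expected input of shape (batch_size, original_dim)")
    batch_size, original_dim = x.shape
    if not 1 <= c <= proj_dim:
        raise ValueError("c must satisfy 1 <= c <= proj_dim")
    gen = torch.Generator().manual_seed(seed)
    # For each input coordinate j, pick c distinct output rows via argsort of random keys.
    rows = torch.rand(original_dim, proj_dim, generator=gen).argsort(dim=1)[:, :c]
    # One random +1/-1 sign per non-zero entry.
    signs = torch.randint(0, 2, (original_dim, c), generator=gen).to(x.dtype) * 2 - 1
    # contrib[b, j, k] = x[b, j] * sign[j, k] / sqrt(c), shape (batch_size, original_dim, c).
    contrib = x.unsqueeze(-1) * signs.to(x.device) / math.sqrt(c)
    # Scatter-add every contribution into its chosen output row.
    y = torch.zeros(batch_size, proj_dim, dtype=x.dtype, device=x.device)
    y.index_add_(1, rows.reshape(-1).to(x.device), contrib.reshape(batch_size, -1))
    return y
```

Applying it to torch.eye(original_dim) returns Sᵀ (one row per input coordinate), which makes the "c non-zeros per column" structure easy to inspect.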

The input must include a batch dimension: input.shape should be (batch_size, original_dim), and the output will have shape (batch_size, proj_dim).
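To project a single vector or a tensor with extra leading dimensions, one option is to flatten it to 2-D first and restore the shape afterwards. The helper project_nd below is hypothetical; proj is assumed to be any callable mapping (batch_size, original_dim) to (batch_size, proj_dim), such as an SJLTProjection instance.

```python
import torch

def project_nd(proj, x: torch.Tensor, original_dim: int) -> torch.Tensor:
    """Apply a batch-only projection to input with any number of leading dims.

    Hypothetical helper: `proj` must accept (batch_size, original_dim)
    and return (batch_size, proj_dim).
    """
    if x.shape[-1] != original_dim:
        raise ValueError(f"last dimension must be {original_dim}, got {x.shape[-1]}")
    lead = x.shape[:-1]                 # empty for a single 1-D vector
    flat = x.reshape(-1, original_dim)  # (N, original_dim), N = 1 for a single vector
    y = proj(flat)                      # (N, proj_dim)
    return y.reshape(*lead, y.shape[-1])
```

For a 1-D input of length original_dim this returns a 1-D tensor of length proj_dim; for shape (A, B, original_dim) it returns (A, B, proj_dim).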

The following snippet shows basic usage of the SJLT CUDA kernel:

import torch
from sjlt import SJLTProjection

# Create projection: 1024 -> 128 dimensions with sparsity 4
proj = SJLTProjection(original_dim=1024, proj_dim=128, c=4)

# Project some data
x = torch.randn(100, 1024, device='cuda')
y = proj(x)  # Shape: [100, 128]

print(f"Compression ratio: {proj.get_compression_ratio():.2f}x")
print(f"Sparsity: {proj.get_sparsity_ratio():.1%}")
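The parameters in the snippet above determine a few quantities directly, sketched below with a hypothetical helper. Whether get_compression_ratio() and get_sparsity_ratio() report exactly these numbers is an assumption; check the package source if the precise definitions matter.

```python
def sjlt_shape_stats(original_dim: int, proj_dim: int, c: int) -> dict:
    """Hypothetical helper: quantities implied by the SJLT parameters alone."""
    return {
        # Dimension reduction factor.
        "dim_reduction": original_dim / proj_dim,
        # Non-zeros in the proj_dim x original_dim matrix: c per column.
        "nnz": c * original_dim,
        # Fraction of matrix entries that are non-zero.
        "density": c / proj_dim,
    }

stats = sjlt_shape_stats(original_dim=1024, proj_dim=128, c=4)
print(f"{stats['dim_reduction']:.2f}x reduction, {stats['nnz']} non-zeros, "
      f"density {stats['density']:.1%}")
```

With these settings each input row costs c · original_dim = 4096 multiply-adds, versus proj_dim · original_dim = 131072 for a dense projection of the same shape.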

Troubleshooting

If installation fails:

  1. Ensure the CUDA toolkit is installed and nvcc is on your PATH
  2. Check that PyTorch can access CUDA: python -c "import torch; print(torch.cuda.is_available())"
  3. Try reinstalling: pip install sjlt --no-cache-dir --force-reinstall
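The first two checks can be collected in one script. This is a hypothetical diagnostic, not part of sjlt; it uses only the standard library plus torch.__version__, torch.version.cuda, and torch.cuda.is_available(), and it still runs if PyTorch is not installed.

```python
import shutil
import sys

def sjlt_env_report() -> dict:
    """Collect the facts the troubleshooting steps ask about (hypothetical helper)."""
    report = {
        "python": sys.version.split()[0],
        "python_ok": sys.version_info >= (3, 8),
        # Step 1: full path to nvcc, or None if it is not on PATH.
        "nvcc": shutil.which("nvcc"),
        "torch": None,
        "torch_cuda_build": None,
        "cuda_available": False,
    }
    try:
        import torch
    except ImportError:
        return report
    report["torch"] = torch.__version__
    # CUDA version PyTorch was built against; None for CPU-only builds.
    report["torch_cuda_build"] = torch.version.cuda
    # Step 2: can PyTorch actually use a GPU?
    report["cuda_available"] = torch.cuda.is_available()
    return report

if __name__ == "__main__":
    for key, value in sjlt_env_report().items():
        print(f"{key:>18}: {value}")
```

If torch_cuda_build is None, the installed PyTorch is a CPU-only build; otherwise compare it with the release reported by nvcc --version, since the CUDA Toolkit should be compatible with your PyTorch installation.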



Download files


Source Distribution

sjlt-0.1.0.tar.gz (9.3 kB)


File details

Details for the file sjlt-0.1.0.tar.gz.

File metadata

  • Download URL: sjlt-0.1.0.tar.gz
  • Upload date:
  • Size: 9.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for sjlt-0.1.0.tar.gz
  • SHA256: 8191705de3a6fc62b7dc5173ff7c90aa183b480cbc60f896723ee0004e57687c
  • MD5: 9837af86401b5c13b269d9e0a062e2d9
  • BLAKE2b-256: 6a705acf06bfe564a3212748f72142555dbe6c9e836c4265d0f6bf1fb45f584a

