
Sparse Johnson-Lindenstrauss Transform with CUDA acceleration


Sparse Johnson-Lindenstrauss Transform CUDA Kernel

This repository provides a simple implementation of the Sparse Johnson-Lindenstrauss Transform (SJLT) with CUDA acceleration for PyTorch.
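For orientation, the standard sparse JL construction can be written as below. This is a sketch of the general technique under stated assumptions: $d$ stands for original_dim, $k$ for proj_dim, $H_j$ is a random set of $c$ distinct output rows for input coordinate $j$, and the signs $\sigma_{ij} \in \{-1, +1\}$ are independent and uniform. The package's exact hashing, sign generation, and scaling are not documented on this page and may differ.

```latex
y = S x, \qquad S \in \mathbb{R}^{k \times d}, \qquad
S_{ij} =
\begin{cases}
  \sigma_{ij} / \sqrt{c}, & i \in H_j, \\
  0, & \text{otherwise}.
\end{cases}

% Each column of S has exactly c entries equal to \pm 1/\sqrt{c}, so it has unit norm.
% Because the signs of different columns are independent with mean zero,
% cross terms vanish in expectation:
\mathbb{E}\left[\lVert S x \rVert_2^2\right]
  = \sum_{j=1}^{d} x_j^2 \sum_{i \in H_j} \frac{1}{c}
  = \lVert x \rVert_2^2 .

% For a batch X of shape (batch_size, d), the projected batch is
Y = X S^{\top} \in \mathbb{R}^{\text{batch\_size} \times k}.
```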

Features

  • GPU-accelerated sparse random projections
  • Supports float32, float64, and bfloat16 data types
  • Optimized CUDA kernels for high performance
  • Easy integration with PyTorch workflows

Installation

Requirements

  • Python >= 3.8
  • PyTorch >= 1.9.0 with CUDA support
  • CUDA Toolkit (version compatible with your PyTorch installation)
  • C++ compiler (GCC 7-11 recommended)
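A small pre-install check along these lines can confirm the Python and PyTorch requirements above. It is a hypothetical helper, not part of the sjlt package: it compares only the numeric parts of version strings, treats a missing PyTorch as a failed check, and does not check the compiler version.

```python
import sys
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of a version, e.g. '2.1.0+cu121' -> (2, 1, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version: str, minimum: str) -> bool:
    """True if `version` >= `minimum`; shorter versions are padded with zeros."""
    a, b = parse_version(version), parse_version(minimum)
    n = max(len(a), len(b))
    a += (0,) * (n - len(a))
    b += (0,) * (n - len(b))
    return a >= b


def check_requirements() -> dict:
    """Check Python >= 3.8 and, if PyTorch is installed, PyTorch >= 1.9.0."""
    report = {"python_ok": sys.version_info >= (3, 8)}
    try:
        import torch
    except ImportError:
        report["torch_ok"] = False
        report["torch_cuda"] = None
        return report
    report["torch_ok"] = meets_minimum(torch.__version__, "1.9.0")
    # torch.version.cuda is None for CPU-only PyTorch builds.
    report["torch_cuda"] = torch.version.cuda
    return report


if __name__ == "__main__":
    print(check_requirements())
```

A `torch_cuda` value of `None` means the installed PyTorch was built without CUDA support, which the requirements above rule out.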

Install from PyPI

pip install sjlt

Install from Source

git clone https://github.com/TRAIS-Lab/sjlt
cd sjlt
pip install -e .

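By default, pip install -e . builds the package in an isolated environment, which can pull in a PyTorch build that differs from the one you have installed. If you see an error like the following, try pip install --no-build-isolation -e . so that the build uses your existing PyTorch installation: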

RuntimeError:
The detected CUDA version (11.8) mismatches the version that was used to compile
PyTorch (12.6). Please make sure to use the same CUDA versions.
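To see which two versions are being compared in an error like this, a hypothetical diagnostic helper can read the local toolkit version from `nvcc --version` and compare it with `torch.version.cuda`, the CUDA version PyTorch was built with. This sketch is not part of sjlt, and its major.minor comparison may be stricter than the check PyTorch itself applies.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract 'X.Y' from nvcc output such as 'Cuda compilation tools, release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def versions_match(nvcc_release: Optional[str], torch_cuda: Optional[str]) -> bool:
    """Compare major.minor of the local CUDA toolkit and PyTorch's CUDA build."""
    if nvcc_release is None or torch_cuda is None:
        return False
    return nvcc_release.split(".")[:2] == torch_cuda.split(".")[:2]


def local_nvcc_release() -> Optional[str]:
    """Return the nvcc release on PATH, or None if nvcc is not found."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)
```

With PyTorch installed, `versions_match(local_nvcc_release(), torch.version.cuda)` returning False reproduces the situation in the error above (for example 11.8 vs 12.6).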

Quick Start

Our SJLT implementation (SJLTProjection) has the following parameters:

  • original_dim: input dimension
  • proj_dim: output dimension
  • c: sparsity parameter, i.e., non-zeros per column (default: 1)
  • threads: CUDA threads per block (default: 1024)
  • fixed_blocks: CUDA blocks to use (default: 84)

The input must include a batch dimension: input.shape should be (batch_size, original_dim), and the output will have shape (batch_size, proj_dim).
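To make the parameters and shapes concrete, here is a minimal CPU reference sketch in plain Python of the standard sparse JL construction, in which c is the number of non-zeros per column and inputs are batched as (batch_size, original_dim). The helper names `make_sjlt` and `project` are hypothetical and this is not the package's kernel; sjlt's random-number generation, hashing, and scaling may differ, so results will not match SJLTProjection numerically. The CUDA launch settings (threads, fixed_blocks) have no counterpart here.

```python
import math
import random
from typing import List, Tuple

Column = List[Tuple[int, float]]  # (output row, value) pairs for one input coordinate


def make_sjlt(original_dim: int, proj_dim: int, c: int = 1, seed: int = 0) -> List[Column]:
    """Build the sparse matrix S column by column.

    Column j gets c distinct output rows, each holding a random sign times 1/sqrt(c).
    """
    if not 1 <= c <= proj_dim:
        raise ValueError("c must be between 1 and proj_dim")
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(c)
    columns = []
    for _ in range(original_dim):
        rows = rng.sample(range(proj_dim), c)  # c distinct rows
        columns.append([(r, rng.choice((-1.0, 1.0)) * scale) for r in rows])
    return columns


def project(batch: List[List[float]], columns: List[Column], proj_dim: int) -> List[List[float]]:
    """Map a (batch_size, original_dim) input to (batch_size, proj_dim)."""
    out = []
    for x in batch:
        if len(x) != len(columns):
            raise ValueError("each row must have length original_dim")
        y = [0.0] * proj_dim
        for j, xj in enumerate(x):
            for r, v in columns[j]:
                y[r] += v * xj  # scatter-add: O(c) work per input coordinate
        out.append(y)
    return out


if __name__ == "__main__":
    cols = make_sjlt(original_dim=1024, proj_dim=128, c=4)
    data = [[random.gauss(0.0, 1.0) for _ in range(1024)] for _ in range(3)]
    y = project(data, cols, proj_dim=128)
    print(len(y), len(y[0]))  # batch_size, proj_dim
```

Each input coordinate touches only c outputs, so the cost per row is proportional to c times original_dim rather than proj_dim times original_dim as with a dense Gaussian projection.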

The following snippet shows basic usage of the SJLT CUDA kernel:

import torch
from sjlt import SJLTProjection

# Create projection: 1024 -> 128 dimensions with sparsity 4
proj = SJLTProjection(original_dim=1024, proj_dim=128, c=4)

# Project some data
x = torch.randn(100, 1024, device='cuda')
y = proj(x)  # Shape: [100, 128]

Troubleshooting

If installation fails:

  1. Ensure CUDA toolkit is installed and nvcc is in PATH
  2. Check that PyTorch was built with CUDA support and can see a GPU: python -c "import torch; print(torch.cuda.is_available())"
  3. Try reinstalling: pip install sjlt --no-cache-dir --force-reinstall
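The first two checks can be collected in one hypothetical diagnostic script like the one below (not part of sjlt); the reinstall in step 3 is still run from the shell.

```python
import shutil


def diagnose() -> dict:
    """Run troubleshooting steps 1 and 2 and collect the results."""
    # Step 1: is nvcc on PATH? (None means it was not found.)
    report = {"nvcc_path": shutil.which("nvcc")}
    try:
        import torch
    except ImportError:
        report["torch_installed"] = False
        report["cuda_available"] = False
        return report
    report["torch_installed"] = True
    # Step 2: can PyTorch use CUDA on this machine?
    report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in diagnose().items():
        print(f"{key}: {value}")
```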


Source Distribution

sjlt-0.1.1.tar.gz (9.5 kB)

File details

Details for the file sjlt-0.1.1.tar.gz.

File metadata

  • Download URL: sjlt-0.1.1.tar.gz
  • Size: 9.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for sjlt-0.1.1.tar.gz:

  • SHA256: 0790f21b212ec8c146a21e6a806a48645a5f895d616946c2a4be6633385915eb
  • MD5: 9ba7e7f23c9a9df5200a2fb2cf30a27c
  • BLAKE2b-256: aaf842cb44c1e6a642c1a70dd56b1e9bda5abbdbb709a0720e03e6e33c0add08
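A downloaded sjlt-0.1.1.tar.gz can be checked against the published SHA256 digest with the standard library's hashlib. The helper names below are hypothetical; the expected digest is copied from the list above.

```python
import hashlib

# SHA256 published for sjlt-0.1.1.tar.gz
EXPECTED_SHA256 = "0790f21b212ec8c146a21e6a806a48645a5f895d616946c2a4be6633385915eb"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file in chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

Calling `verify("sjlt-0.1.1.tar.gz")` in the download directory returns True for an intact file. pip can enforce the same check in hash-checking mode via a requirements-file line such as `sjlt==0.1.1 --hash=sha256:<digest>`.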
