
sklearn's random projection, using JAX to run on a GPU

Project description

JAX Random Projection Transformers

Using JAX to speed up sklearn's random projection transformers

Installation

Note: installing with pip installs the CPU-only version of JAX.

To use a GPU, follow JAX's installation guide to install JAX with GPU support before installing jax-random_projections.

pip install jax-random_projections

Usage

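import jax
from jax_random_projections.sparse import SparseRandomProjectionJAX

# Example input: 100 samples with 10,000 features, as a JAX array
X = jax.random.normal(jax.random.PRNGKey(0), (100, 10000))

transformer = SparseRandomProjectionJAX()
X_new = transformer.fit_transform(X)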

For the API documentation, refer to sklearn's SparseRandomProjection documentation. The only differences are that jax-random_projections currently supports only xla.DeviceArray inputs, does not support dense_output=False, and does not accept y in fit(). The library currently includes only SparseRandomProjection; a future release will also include GaussianRandomProjection.
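
For intuition about what fit_transform computes, here is a minimal NumPy sketch of the sparse random matrix described in sklearn's SparseRandomProjection documentation: with density d (default 1/sqrt(n_features)), each entry is +sqrt(1/d)/sqrt(n_components) or -sqrt(1/d)/sqrt(n_components) with probability d/2 each, and 0 otherwise; the data is then projected as X @ R.T. The helper names sparse_random_matrix and project are hypothetical, and this is not the implementation used by jax-random_projections, which operates on JAX arrays rather than NumPy arrays.

```python
import numpy as np


def sparse_random_matrix(n_components, n_features, density=None, seed=0):
    """Hypothetical helper: draw an (n_components, n_features) sparse random matrix.

    Follows the scheme in sklearn's SparseRandomProjection docs: each entry is
    +s or -s with probability density / 2 each, and 0 otherwise, where
    s = sqrt(1 / density) / sqrt(n_components). Stored densely for simplicity.
    """
    if density is None:
        # Mirrors sklearn's density="auto" default: 1 / sqrt(n_features)
        density = 1.0 / np.sqrt(n_features)
    rng = np.random.default_rng(seed)
    scale = np.sqrt(1.0 / density) / np.sqrt(n_components)
    # Decide which entries are nonzero, then give each one a random sign
    nonzero = rng.random((n_components, n_features)) < density
    signs = rng.choice([-1.0, 1.0], size=(n_components, n_features))
    return np.where(nonzero, signs * scale, 0.0)


def project(X, components):
    # Dense output: (n_samples, n_features) -> (n_samples, n_components)
    return X @ components.T


X = np.random.default_rng(1).normal(size=(50, 1000))
R = sparse_random_matrix(n_components=100, n_features=1000)
X_new = project(X, R)
print(X_new.shape)  # (50, 100)
```

The scale is chosen so that every entry of R has E[R_ij^2] = 1/n_components, which makes squared row norms of X preserved in expectation after projection.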

jax-random_projections also includes SparseRandomProjectionJAXCached, which uses an LRU cache (maxsize=5) to speed up repeated calls by caching the random matrix for data with the same input dimension.
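
The caching idea can be illustrated with functools.lru_cache: the random matrix depends on its shape and seed, not on the data values, so it can be memoized by input dimension. The sketch below is hypothetical and uses NumPy; cached_random_matrix, transform and the seed parameter are illustrative names, not the API of SparseRandomProjectionJAXCached.

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=5)
def cached_random_matrix(n_features, n_components, seed=0):
    """Hypothetical helper: build the random matrix once per
    (n_features, n_components, seed) and reuse it on later calls."""
    rng = np.random.default_rng(seed)
    density = 1.0 / np.sqrt(n_features)
    scale = np.sqrt(1.0 / density) / np.sqrt(n_components)
    nonzero = rng.random((n_components, n_features)) < density
    signs = rng.choice([-1.0, 1.0], size=(n_components, n_features))
    matrix = np.where(nonzero, signs * scale, 0.0)
    # Cached arrays are shared between callers, so make them read-only
    matrix.setflags(write=False)
    return matrix


def transform(X, n_components, seed=0):
    # The cache key is (n_features, n_components, seed); the data itself is
    # not part of it, so any input with the same number of features hits it.
    R = cached_random_matrix(X.shape[1], n_components, seed)
    return X @ R.T


rng = np.random.default_rng(1)
A = transform(rng.normal(size=(10, 500)), n_components=50)
B = transform(rng.normal(size=(20, 500)), n_components=50)  # cache hit
print(cached_random_matrix.cache_info())
# CacheInfo(hits=1, misses=1, maxsize=5, currsize=1)
```

Because lru_cache returns the same object on a hit, the matrix is marked read-only so one caller cannot corrupt another's projection. With maxsize=5, the least recently used matrix is evicted once a sixth distinct key is requested.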


Download files

Source Distribution

jax-random_projections-1.0.1.tar.gz (2.7 kB)

Built Distribution

jax_random_projections-1.0.1-py3-none-any.whl (4.0 kB)

File details

Details for the file jax-random_projections-1.0.1.tar.gz.

File metadata

  • Download URL: jax-random_projections-1.0.1.tar.gz
  • Upload date:
  • Size: 2.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.7.7

File hashes

Hashes for jax-random_projections-1.0.1.tar.gz
Algorithm Hash digest
SHA256 7eef6e89bedfe44429c149bdc367934c201d4706cc364e640320976ecc7d2ab1
MD5 7c5df0e2ead15bd9a37c38439ebe51a7
BLAKE2b-256 0a59e9b52356e516d047a1a374e143519631a5adf3b99f055b04d90780eb6dc5

File details

Details for the file jax_random_projections-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: jax_random_projections-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 4.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.7.7

File hashes

Hashes for jax_random_projections-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 582fca45fcce4c80b8a2fd4879738a3a299d7ac672e08092d07fa307d06a8980
MD5 8ab1f8ed6f0732a2d02982328d07c853
BLAKE2b-256 7d4e126cd335f29e8e6e13468dcaadcd05bad4c0d9e1f7eebf2ff5fe764b9721
