sklearn's random projection with JAX to run on a GPU

Project description

JAX Random Projection Transformers

Using JAX to speed up sklearn's random projection transformers

Installation

Note: Installing with pip installs the CPU-only version of JAX.

To use a GPU, follow JAX's installation guide before installing jax-random_projections.

pip install jax-random_projections

Usage

from jax_random_projections.sparse import SparseRandomProjectionJAX

# X is the input data as a JAX array
transformer = SparseRandomProjectionJAX()
transformer.fit_transform(X)

For the API documentation, refer to sklearn's SparseRandomProjection documentation. The differences are that jax-random_projections currently accepts only xla.DeviceArray inputs, does not support dense_output=False, and does not accept y in fit().

This library currently includes only SparseRandomProjection; a future release will also include GaussianRandomProjection.
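For readers who have not used the sklearn transformer, the following is a minimal pure-Python sketch of the sparse random matrix that sklearn's SparseRandomProjection documentation describes: with s = 1 / density, each entry is -sqrt(s / n_components) or +sqrt(s / n_components) with probability 1 / (2s) each, and 0 otherwise, and the "auto" density is 1 / sqrt(n_features). The function names sparse_random_matrix and project are hypothetical and are not part of jax-random_projections; this is an illustration of the idea, not the library's implementation.

```python
import math
import random


def sparse_random_matrix(n_components, n_features, density=None, seed=0):
    """Hypothetical helper: sparse random matrix as nested lists (illustrative only)."""
    if density is None:
        # "auto" density as documented by sklearn: 1 / sqrt(n_features)
        density = 1.0 / math.sqrt(n_features)
    s = 1.0 / density
    scale = math.sqrt(s / n_components)
    rng = random.Random(seed)
    matrix = []
    for _ in range(n_components):
        row = []
        for _ in range(n_features):
            u = rng.random()
            if u < density / 2:      # probability 1 / (2s)
                row.append(-scale)
            elif u < density:        # probability 1 / (2s)
                row.append(scale)
            else:                    # probability 1 - 1 / s
                row.append(0.0)
        matrix.append(row)
    return matrix


def project(X, matrix):
    """Compute X @ matrix.T for X given as a list of rows."""
    return [[sum(x_j * r_j for x_j, r_j in zip(x, row)) for row in matrix] for x in X]


R = sparse_random_matrix(n_components=4, n_features=16)
X = [[float(i + j) for j in range(16)] for i in range(3)]
X_new = project(X, R)  # 3 rows, each with 4 projected components
```

A GPU-backed version would replace the nested loops with array operations on JAX arrays; the sketch only shows the distribution the matrix entries are drawn from.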

jax-random_projections also includes SparseRandomProjectionJAXCached, which uses an LRU cache (maxsize=5) to speed up repeated calls by caching the random matrix for data with the same input dimension.
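The caching idea can be sketched with functools.lru_cache from the standard library. This is an assumption about the general pattern, not the library's actual code: cached_random_matrix is a hypothetical helper, and the matrix entries are placeholders rather than the real sparse random projection distribution.

```python
import random
from functools import lru_cache


@lru_cache(maxsize=5)
def cached_random_matrix(n_features, n_components, seed=0):
    """Hypothetical helper: build one random matrix per distinct argument tuple."""
    rng = random.Random(seed)
    # Tuples keep the cached value immutable so callers cannot modify it in place.
    return tuple(
        tuple(rng.choice((-1.0, 0.0, 1.0)) for _ in range(n_features))
        for _ in range(n_components)
    )


first = cached_random_matrix(16, 4)   # cache miss: the matrix is generated
second = cached_random_matrix(16, 4)  # cache hit: the stored matrix is returned
third = cached_random_matrix(32, 4)   # different input dimension: a new matrix
info = cached_random_matrix.cache_info()
```

With maxsize=5, at most five distinct matrices are kept; a sixth distinct input dimension evicts the least recently used one.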

Download files

Source Distribution: jax-random_projections-1.0.1.tar.gz (2.7 kB)

Built Distribution: jax_random_projections-1.0.1-py3-none-any.whl (4.0 kB, Python 3)
