Run sklearn's random projection transformers on a GPU with JAX
JAX Random Projection Transformers
Using JAX to speed up sklearn's random projection transformers
Installation
Note: Installing with pip installs the CPU-only version of JAX. To use a GPU, follow JAX's installation guide before installing jax-random_projections.
pip install jax-random_projections
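After installing, you can check which backend JAX will actually use. A CPU-only install (the pip default noted above) reports "cpu", while a GPU-enabled install with a visible device should report "gpu". This is a minimal sketch that relies only on JAX's public `jax.default_backend()` and `jax.devices()` and does not touch jax-random_projections itself.

```python
import jax

# "cpu" for the CPU-only pip build; "gpu" (or "tpu") when an accelerator
# build of JAX is installed and a device is visible.
backend = jax.default_backend()

# The list of devices JAX can place arrays on.
devices = jax.devices()

print(f"JAX backend: {backend}")
print(f"Devices: {devices}")
```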
Usage
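import jax
from jax_random_projections.sparse import SparseRandomProjectionJAX

X = jax.random.normal(jax.random.PRNGKey(0), (1000, 10000))  # example data: 1000 samples, 10000 features
transformer = SparseRandomProjectionJAX()
X_new = transformer.fit_transform(X)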
For the API documentation, refer to sklearn's SparseRandomProjection documentation. The only differences are that jax-random_projections currently accepts only xla.DeviceArray inputs, does not support dense_output=False, and does not accept the y argument of fit().
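Because the API mirrors sklearn's, it may help to see roughly what a sparse random projection computes. The sketch below is plain JAX, not this library's implementation. It assumes sklearn's documented distribution for the matrix entries: with s = 1/density, each entry is +sqrt(s / n_components) or -sqrt(s / n_components) with probability 1/(2s) each and 0 otherwise, with density defaulting to 1/sqrt(n_features). For simplicity the matrix is stored densely. The helper names `sparse_random_matrix` and `project` are hypothetical.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def sparse_random_matrix(key, n_components, n_features, density=None):
    """Hypothetical helper: dense JAX array with sklearn-style sparse random entries.

    With s = 1 / density, each entry is
      +sqrt(s / n_components) with probability 1 / (2s),
      0                       with probability 1 - 1 / s,
      -sqrt(s / n_components) with probability 1 / (2s).
    density defaults to 1 / sqrt(n_features), like sklearn's density="auto".
    """
    if density is None:
        density = 1.0 / math.sqrt(n_features)
    s = 1.0 / density
    shape = (n_components, n_features)
    key_mask, key_sign = jax.random.split(key)
    # Keep each entry with probability `density`.
    nonzero = jax.random.bernoulli(key_mask, p=density, shape=shape)
    # Independent random sign, +1 or -1 with equal probability.
    signs = jnp.where(jax.random.bernoulli(key_sign, p=0.5, shape=shape), 1.0, -1.0)
    return jnp.where(nonzero, signs * math.sqrt(s / n_components), 0.0)


@jax.jit
def project(X, R):
    # (n_samples, n_features) @ (n_features, n_components) -> (n_samples, n_components)
    return X @ R.T


# NumPy data is converted to a JAX array on the default device first,
# since the library only accepts JAX device arrays.
X_np = np.random.default_rng(0).standard_normal((200, 5000)).astype(np.float32)
X = jnp.asarray(X_np)
R = sparse_random_matrix(jax.random.PRNGKey(0), n_components=256, n_features=X.shape[1])
X_new = project(X, R)
print(X_new.shape)  # (200, 256)
```

The entry scaling makes each projected squared norm match the original squared norm in expectation, which is why the transformer can reduce dimensionality while approximately preserving distances.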
This library currently includes only SparseRandomProjection; a future release will also include GaussianRandomProjection.
jax-random_projections also includes SparseRandomProjectionJAXCached, which uses an LRU cache (maxsize=5) to speed up repeated calls by caching the random matrix for data with the same input dimension.
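The caching idea can be illustrated with Python's `functools.lru_cache`. This is a minimal sketch of the technique, not the code behind SparseRandomProjectionJAXCached. It assumes the cache key is the input width plus the number of components and a seed. The names `cached_sparse_matrix` and `transform` are hypothetical, and the matrix uses the same sklearn-style sparse distribution as above.

```python
import math
from functools import lru_cache

import jax
import jax.numpy as jnp


@lru_cache(maxsize=5)
def cached_sparse_matrix(n_features, n_components, seed=0):
    """Hypothetical helper: build the projection matrix once per input width.

    lru_cache keys on the (hashable) arguments, so a second call with the same
    n_features / n_components / seed returns the stored array instead of
    sampling a new one. At most 5 keys are kept; the least recently used
    entry is evicted first.
    """
    density = 1.0 / math.sqrt(n_features)
    shape = (n_components, n_features)
    key_mask, key_sign = jax.random.split(jax.random.PRNGKey(seed))
    nonzero = jax.random.bernoulli(key_mask, p=density, shape=shape)
    signs = jnp.where(jax.random.bernoulli(key_sign, p=0.5, shape=shape), 1.0, -1.0)
    # sqrt(s / n_components) with s = 1 / density
    return jnp.where(nonzero, signs * math.sqrt(1.0 / (density * n_components)), 0.0)


def transform(X, n_components=128, seed=0):
    # The cache key depends only on the input width, not on the data itself.
    R = cached_sparse_matrix(X.shape[1], n_components, seed)
    return X @ R.T


key = jax.random.PRNGKey(1)
X1 = jax.random.normal(key, (50, 2000))
X2 = jax.random.normal(jax.random.split(key)[0], (80, 2000))
Y1 = transform(X1)  # cache miss: the matrix is sampled
Y2 = transform(X2)  # same width (2000): the matrix is served from the cache
print(cached_sparse_matrix.cache_info())
# CacheInfo(hits=1, misses=1, maxsize=5, currsize=1)
```

Reusing the matrix also keeps projections of different batches with the same width consistent with each other, since they share one random matrix.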
Download files
Source Distribution
Hashes for jax-random_projections-1.0.1.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 7eef6e89bedfe44429c149bdc367934c201d4706cc364e640320976ecc7d2ab1
MD5 | 7c5df0e2ead15bd9a37c38439ebe51a7
BLAKE2b-256 | 0a59e9b52356e516d047a1a374e143519631a5adf3b99f055b04d90780eb6dc5

Built Distribution
Hashes for jax_random_projections-1.0.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 582fca45fcce4c80b8a2fd4879738a3a299d7ac672e08092d07fa307d06a8980
MD5 | 8ab1f8ed6f0732a2d02982328d07c853
BLAKE2b-256 | 7d4e126cd335f29e8e6e13468dcaadcd05bad4c0d9e1f7eebf2ff5fe764b9721