sklearn's random projection with JAX to run on a GPU
JAX Random Projection Transformers
Using JAX to speed up sklearn's random projection transformers
Installation
Note: Installing with pip installs the CPU-only version of JAX.
To use a GPU, follow JAX's installation guide before installing jax-random_projections.
pip install jax-random_projections
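After installing, you can check which backend JAX picked up. `jax.devices()` and `jax.default_backend()` are standard JAX functions. With the CPU-only JAX that pip installs by default, only CPU devices are listed.

```python
import jax

# Devices visible to JAX; a GPU-enabled install should list GPU devices here.
devices = jax.devices()
print(devices)

# Name of the backend JAX uses by default, e.g. "cpu" for the CPU-only install.
backend = jax.default_backend()
print(backend)
```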
Usage
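import jax.numpy as jnp
from jax_random_projections.sparse import SparseRandomProjectionJAX
X = jnp.ones((100, 10000))  # example data; the input must be a JAX array
transformer = SparseRandomProjectionJAX()
X_new = transformer.fit_transform(X)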
For the API documentation, refer to sklearn's SparseRandomProjection documentation.
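As a rough illustration of what the transformer computes, the sketch below builds a sparse random matrix the way sklearn's SparseRandomProjection documentation describes it (density 1/sqrt(n_features) by default, nonzero entries ±sqrt(1/density)/sqrt(n_components) with equal probability) and projects a JAX array with it. It is not this library's implementation: `jl_min_dim`, `sparse_random_matrix` and `sparse_random_projection` are hypothetical names, and the matrix is held as a dense JAX array for simplicity, whereas sklearn stores it as a SciPy sparse matrix. `jl_min_dim` follows the Johnson-Lindenstrauss bound sklearn uses when `n_components="auto"`.

```python
import math

import jax
import jax.numpy as jnp


def jl_min_dim(n_samples: int, eps: float = 0.1) -> int:
    """Johnson-Lindenstrauss bound, as used by sklearn for n_components="auto"."""
    denominator = (eps ** 2) / 2 - (eps ** 3) / 3
    return int(4 * math.log(n_samples) / denominator)


def sparse_random_matrix(key, n_components: int, n_features: int, density=None):
    """Hypothetical helper: dense array holding a sparse random projection matrix."""
    if density is None:
        # Default density 1 / sqrt(n_features), like sklearn's density="auto".
        density = 1.0 / math.sqrt(n_features)
    shape = (n_components, n_features)
    key_mask, key_sign = jax.random.split(key)
    # Each entry is nonzero with probability `density` ...
    nonzero = jax.random.bernoulli(key_mask, p=density, shape=shape)
    # ... and, when nonzero, positive or negative with equal probability.
    positive = jax.random.bernoulli(key_sign, p=0.5, shape=shape)
    scale = math.sqrt(1.0 / density) / math.sqrt(n_components)
    signed = jnp.where(positive, scale, -scale)
    return jnp.where(nonzero, signed, 0.0)


def sparse_random_projection(key, X, n_components=None):
    """Hypothetical helper: project X of shape (n_samples, n_features)."""
    n_samples, n_features = X.shape
    if n_components is None:
        n_components = jl_min_dim(n_samples)
    components = sparse_random_matrix(key, n_components, n_features)
    # Same orientation as sklearn's transform: X @ components_.T
    return X @ components.T


key_data, key_proj = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_data, (8, 400))
X_new = sparse_random_projection(key_proj, X, n_components=50)  # shape (8, 50)
```

Only about 1/sqrt(n_features) of the matrix entries are nonzero, which is what makes the sparse variant cheaper to store than a Gaussian projection matrix.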
The only differences are that jax-random_projections currently supports only xla.DeviceArray inputs, and that it does not support dense_output=False or the y argument of fit().
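Because the transformers accept only JAX device arrays, NumPy data has to be converted first. A minimal sketch, assuming a hypothetical helper named `to_jax_array`; `jnp.asarray` and `jax.device_put` are standard JAX calls. Newer JAX releases replaced `xla.DeviceArray` with the unified `jax.Array` type, so the type name above may not match your JAX version.

```python
import numpy as np
import jax
import jax.numpy as jnp


def to_jax_array(X, dtype=jnp.float32):
    """Hypothetical helper: convert array-like input to an array on JAX's default device."""
    return jax.device_put(jnp.asarray(X, dtype=dtype))


X_np = np.random.default_rng(0).normal(size=(100, 10_000))
X = to_jax_array(X_np)

# With the library's transformer (not imported here), only X is passed:
#   transformer.fit(X)                 # no y argument
#   X_new = transformer.transform(X)   # output is dense; dense_output=False is unsupported
```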
This library currently includes only SparseRandomProjection; a future release will also include GaussianRandomProjection.
jax-random_projections also includes SparseRandomProjectionJAXCached, which uses an LRU cache (maxsize=5) to speed up repeated calls by caching the random matrix for data with the same input dimension (number of features).
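The caching idea can be sketched with `functools.lru_cache`: because the random matrix depends only on the input dimension (plus the requested output size and the seed), those hashable values can serve as the cache key. This is not SparseRandomProjectionJAXCached's actual code; `cached_random_matrix` and `project` are hypothetical names, and a dense ±1/sqrt(n_components) matrix (density=1 in sklearn's scheme) stands in for the sparse one to keep the example short.

```python
import functools
import math

import jax
import jax.numpy as jnp


@functools.lru_cache(maxsize=5)
def cached_random_matrix(n_features: int, n_components: int, seed: int = 0):
    """Hypothetical helper: build the matrix once per (n_features, n_components, seed)."""
    key = jax.random.PRNGKey(seed)
    positive = jax.random.bernoulli(key, p=0.5, shape=(n_components, n_features))
    scale = 1.0 / math.sqrt(n_components)
    # Dense +/- 1/sqrt(n_components) entries (density=1 in sklearn's scheme).
    return jnp.where(positive, scale, -scale)


def project(X, n_components: int, seed: int = 0):
    # X.shape[1] is a plain int, so it is hashable and usable as a cache key.
    components = cached_random_matrix(X.shape[1], n_components, seed)
    return X @ components.T
```

Repeated calls on data of the same width reuse the cached matrix; once more than five distinct keys have been used, the least recently used matrix is evicted.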
Download files
Source Distribution
Built Distribution
File details
Details for the file jax-random_projections-1.0.1.tar.gz.
File metadata
- Download URL: jax-random_projections-1.0.1.tar.gz
- Upload date:
- Size: 2.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.7.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7eef6e89bedfe44429c149bdc367934c201d4706cc364e640320976ecc7d2ab1 |
| MD5 | 7c5df0e2ead15bd9a37c38439ebe51a7 |
| BLAKE2b-256 | 0a59e9b52356e516d047a1a374e143519631a5adf3b99f055b04d90780eb6dc5 |
File details
Details for the file jax_random_projections-1.0.1-py3-none-any.whl.
File metadata
- Download URL: jax_random_projections-1.0.1-py3-none-any.whl
- Upload date:
- Size: 4.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.4.0 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.7.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 582fca45fcce4c80b8a2fd4879738a3a299d7ac672e08092d07fa307d06a8980 |
| MD5 | 8ab1f8ed6f0732a2d02982328d07c853 |
| BLAKE2b-256 | 7d4e126cd335f29e8e6e13468dcaadcd05bad4c0d9e1f7eebf2ff5fe764b9721 |