Classical machine learning algorithms on the GPU/TPU.
Scikit-JAX: Classical Machine Learning on the GPU
Welcome to Scikit-JAX, a machine learning library that uses JAX to run classical machine learning algorithms efficiently and at scale on GPUs. Our library provides implementations of several classical machine learning techniques, optimized for performance and ease of use.
Features
- Linear Regression: Implemented with options for different weight initialization methods and dropout regularization.
- KMeans: Clustering algorithm that partitions data points into a chosen number of clusters.
- Principal Component Analysis (PCA): Dimensionality reduction technique to simplify data while preserving essential features.
- Multinomial Naive Bayes: Classifier suitable for discrete data, such as text classification tasks.
- Gaussian Naive Bayes: Classifier for continuous data with a normal distribution assumption.
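PCA is listed above but has no usage example in this guide. To illustrate the technique itself (not Scikit-JAX's API or implementation), here is a minimal NumPy sketch that computes PCA from the singular value decomposition of the centered data. The `pca` function and its return values are hypothetical; `jax.numpy` mirrors most of the NumPy calls used here.

```python
import numpy as np


def pca(X, n_components):
    """Hypothetical helper: project X (n_samples, n_features) onto its top principal components."""
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    Xc = X - mean  # center every feature at zero
    # Rows of Vt are the principal directions, sorted by decreasing singular value
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    components = Vt[:n_components]
    # Variance of the data along each kept direction
    explained_variance = S[:n_components] ** 2 / (X.shape[0] - 1)
    projected = Xc @ components.T  # shape (n_samples, n_components)
    return projected, components, explained_variance


# Points on the line y = 2x: a single component captures all of the variance
X = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]])
Z, components, explained_variance = pca(X, n_components=1)
print(Z.shape, explained_variance)
```

Adding `X.mean(axis=0)` back to `Z @ components` reconstructs the original points here, because the data lie exactly on one direction.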
Installation
Scikit-JAX is available on PyPI and can be installed with pip:
pip install scikit-jax
Usage
Here is a quick guide on how to use the key components of Scikit-JAX.
Linear Regression
from skjax.linear_model import LinearRegression
# X_train, y_train and X_test are your own arrays of features and targets
# Initialize the model
model = LinearRegression(weights_init='xavier', epochs=100, learning_rate=0.01)
# Fit the model
model.fit(X_train, y_train)
# Make predictions
predictions = model.predict(X_test)
# Plot losses
model.plot_losses()
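This guide does not say how `fit` trains the model. Assuming it runs full-batch gradient descent on the mean squared error, with Xavier (Glorot) uniform weight initialization and dropout applied to the input features (all assumptions), the training loop might look roughly like the NumPy sketch below. `fit_linear_regression` is a hypothetical helper, not part of Scikit-JAX.

```python
import numpy as np


def fit_linear_regression(X, y, epochs=100, learning_rate=0.01, dropout=0.0, seed=0):
    """Hypothetical helper: full-batch gradient descent on mean squared error."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    n_samples, n_features = X.shape
    # Xavier/Glorot uniform init: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out))
    limit = np.sqrt(6.0 / (n_features + 1))
    w = rng.uniform(-limit, limit, size=(n_features, 1))
    b = 0.0
    losses = []
    for _ in range(epochs):
        X_in = X
        if dropout > 0.0:
            # Inverted dropout on the inputs: zero features at random, rescale the rest
            mask = rng.random(X.shape) >= dropout
            X_in = X * mask / (1.0 - dropout)
        residual = X_in @ w + b - y
        losses.append(float(np.mean(residual ** 2)))  # loss before this epoch's update
        # Gradients of the MSE with respect to w and b
        grad_w = 2.0 / n_samples * X_in.T @ residual
        grad_b = 2.0 * float(np.mean(residual))
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b, losses


# Noise-free data from y = 3x + 1
X = np.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = 3.0 * X[:, 0] + 1.0
w, b, losses = fit_linear_regression(X, y, epochs=500, learning_rate=0.5)
predictions = X @ w + b
```

The returned `losses` list is the kind of per-epoch history that a method like `plot_losses()` could plot.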
K-Means
from skjax.clustering import KMeans
# Initialize the model
kmeans = KMeans(num_clusters=3)
# Fit the model
kmeans.fit(X_train)
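The example above only fits the model. A common way to fit k-means is Lloyd's algorithm, which alternates between assigning each point to its nearest centroid and moving each centroid to the mean of its points. The NumPy sketch below assumes centroids are initialized from randomly chosen data points; Scikit-JAX may initialize or stop differently, and the `kmeans` function is hypothetical.

```python
import numpy as np


def kmeans(X, num_clusters, num_iters=100, seed=0):
    """Hypothetical helper: Lloyd's algorithm for k-means clustering."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)

    def assign(centroids):
        # Squared Euclidean distance from every point to every centroid, shape (n, k)
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        return dists.argmin(axis=1)

    # Initialize the centroids with distinct, randomly chosen data points
    centroids = X[rng.choice(len(X), size=num_clusters, replace=False)]
    for _ in range(num_iters):
        labels = assign(centroids)
        # Move each centroid to the mean of its points (leave empty clusters in place)
        new_centroids = np.array([
            X[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
            for k in range(num_clusters)
        ])
        converged = np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if converged:
            break
    return centroids, assign(centroids)


# Two well-separated groups of three points each
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
centroids, labels = kmeans(X, num_clusters=2)
```

Labels are recomputed from the final centroids before returning, so the two outputs always agree with each other.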
Gaussian Naive Bayes
from skjax.naive_bayes import GaussianNaiveBayes
# Initialize the model
nb = GaussianNaiveBayes()
# Fit the model
nb.fit(X_train, y_train)
# Make predictions
predictions = nb.predict(X_test)
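Gaussian Naive Bayes models each feature within each class as an independent normal distribution and predicts the class with the highest log prior plus summed log-likelihood. The class below is a hypothetical NumPy sketch of that idea, not Scikit-JAX's `GaussianNaiveBayes`; the `var_smoothing` constant added to the variances is an assumption for numerical stability.

```python
import numpy as np


class GaussianNBSketch:
    """Hypothetical minimal Gaussian Naive Bayes classifier (not Scikit-JAX's class)."""

    def __init__(self, var_smoothing=1e-9):
        # Small constant added to every variance so no feature has zero variance
        self.var_smoothing = var_smoothing

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        # Per-class log prior, feature means and feature variances
        self.log_priors_ = np.log([np.mean(y == c) for c in self.classes_])
        self.means_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        self.vars_ = np.array([X[y == c].var(axis=0) for c in self.classes_]) + self.var_smoothing
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        diff = X[:, None, :] - self.means_[None, :, :]  # shape (n, k, d)
        # Sum of independent 1-D normal log densities over features, shape (n, k)
        log_lik = -0.5 * (np.log(2 * np.pi * self.vars_) + diff ** 2 / self.vars_).sum(axis=-1)
        # Pick the class with the largest log posterior (up to a constant)
        return self.classes_[np.argmax(log_lik + self.log_priors_, axis=1)]


X_train = np.array([[1.0], [1.2], [0.8], [5.0], [5.2], [4.8]])
y_train = np.array([0, 0, 0, 1, 1, 1])
nb = GaussianNBSketch().fit(X_train, y_train)
predictions = nb.predict(np.array([[1.1], [4.9], [3.5]]))
```

Working in log space avoids the underflow that multiplying many small densities would cause.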
License
Scikit-JAX is licensed under the MIT License.
Download files
Source Distribution: scikit_jax-0.0.3.dev0.tar.gz
Built Distribution: scikit_jax-0.0.3.dev0-py3-none-any.whl
File details
Details for the file scikit_jax-0.0.3.dev0.tar.gz
File metadata
- Download URL: scikit_jax-0.0.3.dev0.tar.gz
- Upload date:
- Size: 13.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | d2a50514c39a47a4bb7e036236db80b57a86fbf188f9d84598946e7e32c2e90a
MD5 | a4fb6cf82708576fedd7f038c1f3ca1f
BLAKE2b-256 | 0e710456ba3574878f25652ac364773edcaee7f9f33e2a296a830b94c8548002
File details
Details for the file scikit_jax-0.0.3.dev0-py3-none-any.whl
File metadata
- Download URL: scikit_jax-0.0.3.dev0-py3-none-any.whl
- Upload date:
- Size: 16.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | d32ebf3b24251dc97b9971725b3996c607d63378706962e68386967e5b814ef6
MD5 | ff6abc37339277147c1843102ef4cddf
BLAKE2b-256 | 83ddd83752e9307b8f3c5950894b7cf54ba4df3825f33ce2834ab97cc4919064