
Classical machine learning algorithms on the GPU.

Project description


Scikit-JAX: Classical Machine Learning on the GPU

Welcome to Scikit-JAX, a machine learning library that uses JAX to run classical machine learning algorithms efficiently on GPUs. The library implements a range of classical techniques, with a focus on performance and ease of use.

Features

  • Linear Regression: Supports several weight initialization methods and dropout regularization.
  • KMeans: Clustering algorithm that partitions data points into a fixed number of clusters.
  • Principal Component Analysis (PCA): Dimensionality reduction technique that projects data onto the directions of greatest variance, preserving as much of its structure as possible.
  • Multinomial Naive Bayes: Classifier for discrete count data, such as word counts in text classification.
  • Gaussian Naive Bayes: Classifier for continuous data that assumes each feature is normally distributed within each class.
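Of the techniques above, PCA is the one whose mechanics fit most easily into a few lines. The sketch below illustrates the technique in plain JAX via the singular value decomposition. It is not Scikit-JAX's implementation, and the `pca` function and its return values are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration; not part of skjax.
def pca(X, n_components):
    # Center each feature so the principal axes pass through the data mean.
    Xc = X - X.mean(axis=0)
    # Thin SVD: rows of Vt are principal directions, sorted by decreasing singular value.
    _, S, Vt = jnp.linalg.svd(Xc, full_matrices=False)
    components = Vt[:n_components]
    # Variance of the data along each retained direction.
    explained_variance = S[:n_components] ** 2 / (X.shape[0] - 1)
    # Project the centered data onto the retained directions.
    return Xc @ components.T, components, explained_variance

X = jax.random.normal(jax.random.PRNGKey(0), (200, 5))
Z, components, explained_variance = pca(X, n_components=2)
```

Here `Z` holds the data projected onto the first two principal components, and `explained_variance` gives the variance captured by each of them.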

Installation

Scikit-JAX is available on PyPI and can be installed with pip:

pip install scikit-jax
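Whether Scikit-JAX actually runs on a GPU depends on the JAX installation underneath it, since JAX provides separate CPU and GPU builds. A quick check that uses only JAX itself (no Scikit-JAX API) is to list the devices JAX can see and run a small computation on the default one:

```python
import jax
import jax.numpy as jnp

# Devices visible to JAX: GPU devices appear only with a GPU-enabled JAX build;
# otherwise JAX falls back to the CPU.
devices = jax.devices()
print(devices)
print("Default backend:", jax.default_backend())

# New arrays are placed on the default device, so this matrix product runs
# on the GPU when one is available.
x = jnp.ones((256, 256))
y = (x @ x).block_until_ready()
```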

Usage

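Here is a quick guide to the key components of Scikit-JAX. In each example, X_train, y_train and X_test stand for your own training features, training targets and test features.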

Linear Regression

from skjax.linear_model import LinearRegression

# Initialize the model
model = LinearRegression(weights_init='xavier', epochs=100, learning_rate=0.01)

# Fit the model
model.fit(X_train, y_train)

# Make predictions
predictions = model.predict(X_test)

# Plot losses
model.plot_losses()
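The `epochs`, `learning_rate` and `plot_losses()` names suggest that the model is trained iteratively and records a loss history. The sketch below shows one way such a model can be trained in plain JAX: Xavier (Glorot) uniform initialization followed by full-batch gradient descent on mean squared error. This is an assumption about the general approach, not Scikit-JAX's code; it omits dropout, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not part of skjax.
def xavier_init(key, n_in, n_out):
    # Xavier/Glorot uniform: sample from U(-limit, limit) with limit = sqrt(6 / (n_in + n_out)).
    limit = jnp.sqrt(6.0 / (n_in + n_out))
    return jax.random.uniform(key, (n_in, n_out), minval=-limit, maxval=limit)

def mse_loss(params, X, y):
    w, b = params
    return jnp.mean((X @ w + b - y) ** 2)

@jax.jit
def gd_step(params, X, y, lr):
    # One full-batch gradient-descent step; returns updated params and the current loss.
    loss, grads = jax.value_and_grad(mse_loss)(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

def fit(X, y, epochs=100, learning_rate=0.01, seed=0):
    w = xavier_init(jax.random.PRNGKey(seed), X.shape[1], 1)
    params = (w, jnp.zeros((1,)))
    losses = []  # per-epoch loss history, the kind of data a loss plot would draw
    for _ in range(epochs):
        params, loss = gd_step(params, X, y, learning_rate)
        losses.append(float(loss))
    return params, losses

# Synthetic, noise-free data: y = 2x + 1, with y shaped (n_samples, 1).
X = jax.random.uniform(jax.random.PRNGKey(1), (100, 1), minval=-1.0, maxval=1.0)
y = 2.0 * X + 1.0
(w, b), losses = fit(X, y, epochs=500, learning_rate=0.1)
```

Because the data are noise-free, the learned weight and bias should approach 2 and 1.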

K-Means

from skjax.clustering import KMeans

# Initialize the model
kmeans = KMeans(num_clusters=3)

# Fit the model
kmeans.fit(X_train)

# Predict clusters
clusters = kmeans.predict(X_test)
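K-means is commonly fitted with Lloyd's algorithm, which alternates between assigning each point to its nearest centroid and moving each centroid to the mean of its assigned points. Below is a minimal JAX sketch of that loop. It is not necessarily how Scikit-JAX's `KMeans` works, and the function names, random initialization and fixed iteration count are assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not part of skjax.
@jax.jit
def assign(X, centroids):
    # Squared Euclidean distance from every point to every centroid: (n_samples, k).
    d2 = jnp.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(d2, axis=1)

@jax.jit
def update(X, labels, centroids):
    k = centroids.shape[0]
    one_hot = jax.nn.one_hot(labels, k)   # (n_samples, k) membership matrix
    counts = one_hot.sum(axis=0)          # number of points in each cluster
    sums = one_hot.T @ X                  # per-cluster sum of points
    means = sums / jnp.maximum(counts, 1.0)[:, None]
    # An empty cluster keeps its previous centroid instead of collapsing to zero.
    return jnp.where(counts[:, None] > 0, means, centroids)

def kmeans(X, num_clusters, num_iters=50, seed=0):
    # Initialize with num_clusters distinct data points chosen at random.
    idx = jax.random.choice(jax.random.PRNGKey(seed), X.shape[0], (num_clusters,), replace=False)
    centroids = X[idx]
    for _ in range(num_iters):
        centroids = update(X, assign(X, centroids), centroids)
    return centroids, assign(X, centroids)

# Three synthetic 2-D blobs of 50 points each.
keys = jax.random.split(jax.random.PRNGKey(42), 3)
centers = jnp.array([[0.0, 0.0], [5.0, 5.0], [10.0, 0.0]])
X = jnp.concatenate([c + 0.3 * jax.random.normal(k, (50, 2)) for k, c in zip(keys, centers)])
centroids, labels = kmeans(X, num_clusters=3)
```

In this sketch, clustering new points corresponds to calling `assign(X_new, centroids)` with the fitted centroids.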

Gaussian Naive Bayes

from skjax.naive_bayes import GaussianNaiveBayes

# Initialize the model
nb = GaussianNaiveBayes()

# Fit the model
nb.fit(X_train, y_train)

# Make predictions
predictions = nb.predict(X_test)
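Gaussian Naive Bayes estimates a prior, plus a mean and a variance for every class and feature, then predicts the class with the highest posterior, treating features as independent given the class. The sketch below shows the technique in JAX, assuming integer class labels 0 to num_classes - 1. The names `gnb_fit` and `gnb_predict` and the fixed variance floor are illustrative choices, not Scikit-JAX's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not part of skjax.
def gnb_fit(X, y, num_classes, var_floor=1e-6):
    one_hot = jax.nn.one_hot(y, num_classes)          # (n_samples, n_classes)
    counts = one_hot.sum(axis=0)                      # samples per class
    log_priors = jnp.log(counts / counts.sum())
    means = (one_hot.T @ X) / counts[:, None]         # (n_classes, n_features)
    # Per-class variance of each feature, plus a small floor for numerical stability.
    variances = (one_hot.T @ (X - means[y]) ** 2) / counts[:, None] + var_floor
    return log_priors, means, variances

@jax.jit
def gnb_predict(params, X):
    log_priors, means, variances = params
    diff = X[:, None, :] - means[None, :, :]          # (n_samples, n_classes, n_features)
    # Naive assumption: the log-likelihood is a sum of per-feature Gaussian log-densities.
    log_lik = -0.5 * jnp.sum(jnp.log(2 * jnp.pi * variances) + diff**2 / variances, axis=-1)
    # Highest log posterior (up to a constant) wins.
    return jnp.argmax(log_priors + log_lik, axis=1)

# Two synthetic 2-D classes centered at (-2, -2) and (2, 2).
k0, k1 = jax.random.split(jax.random.PRNGKey(0))
X = jnp.concatenate([jax.random.normal(k0, (100, 2)) - 2.0, jax.random.normal(k1, (100, 2)) + 2.0])
y = jnp.concatenate([jnp.zeros(100, dtype=jnp.int32), jnp.ones(100, dtype=jnp.int32)])
params = gnb_fit(X, y, num_classes=2)
accuracy = float(jnp.mean(gnb_predict(params, X) == y))
```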

License

Scikit-JAX is licensed under the MIT License.



Download files


Source Distribution

scikit-jax-0.0.1.tar.gz (10.3 kB)

Built Distribution

scikit_jax-0.0.1-py3-none-any.whl (12.5 kB)

File details

Details for the file scikit-jax-0.0.1.tar.gz.

File metadata

  • Download URL: scikit-jax-0.0.1.tar.gz
  • Upload date:
  • Size: 10.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for scikit-jax-0.0.1.tar.gz

  • SHA256: 9bc5c6ca94a802ecd0ae0af66663c530bbc475bd971556898e3c13a9ff0a8cee
  • MD5: 7e5ea04dd2edfa434211855795b55377
  • BLAKE2b-256: b1d7507e6e795cb0b5758a8356485b84b7a1f6eac6f98c26eee81382ca8320ee


File details

Details for the file scikit_jax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: scikit_jax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 12.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for scikit_jax-0.0.1-py3-none-any.whl

  • SHA256: 0d1ec4b0c27dde09220d2fcd21c0ed42378d144402ce1273d9b4126a10d383f5
  • MD5: 4911513ec351fa69c05917bbe1e0daa7
  • BLAKE2b-256: 83505587184a8d389be2d2039b8b5e0a652253cdafc0d76decfaca8abeec5183

