
Composable kernels for scikit-learn implemented in JAX.


sklearn-jax-kernels


Warning: This project is still at an early stage, so the API may change in the future. Functionality is currently limited to the use cases that motivated the project (applications to DNA sequences in biology).

Why?

Ever wanted to run a kernel-based model from scikit-learn on a relatively large dataset? If so, you will have noticed that this can take extraordinarily long and require huge amounts of memory, especially when using compositions of kernels (for example k1 * k2 + k3). This is due to the way kernels are computed in scikit-learn: for each kernel, the complete kernel matrix is computed, and the compositions are then computed from these kernel matrices. Furthermore, scikit-learn does not rely on an automatic differentiation framework for computing gradients through kernel operations.
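To make this concrete, the sketch below uses scikit-learn's own kernels (`sklearn.gaussian_process.kernels`, not this package). It checks that evaluating a composite such as k1 * k2 + k3 gives the same result as materializing each component's full n x n matrix and combining them, and that requesting gradients returns an additional array of shape n x n x n_hyperparameters. The specific kernels and the random data are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.gaussian_process.kernels import RBF, DotProduct, ConstantKernel

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))

k1 = RBF(length_scale=1.0)
k2 = DotProduct(sigma_0=1.0)
k3 = ConstantKernel(constant_value=2.0)

# Composite kernel built with scikit-learn's operator overloading.
composite = k1 * k2 + k3

# Evaluating the composite matches evaluating every component on the
# whole dataset and combining the resulting n x n matrices.
K_composite = composite(X)
K_manual = k1(X) * k2(X) + k3(X)

# With eval_gradient=True, scikit-learn also returns one n x n gradient
# matrix per hyperparameter (here: length_scale, sigma_0, constant_value).
K, K_grad = composite(X, eval_gradient=True)

print(K_composite.shape)                   # (50, 50)
print(np.allclose(K_composite, K_manual))  # True
print(K_grad.shape)                        # (50, 50, 3)
```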

Introduction

sklearn-jax-kernels was designed to circumvent these issues:

  • Using JAX allows kernel computations to be accelerated through XLA optimizations and run on GPUs, and it simplifies the computation of gradients through kernels
  • Kernels are composed on a per-element basis, so that unnecessary intermediate copies can be optimized away by JAX compilation
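The per-element composition idea can be sketched in plain JAX as follows. This is an illustration under our own assumptions, not the package's internal implementation; `rbf`, `linear`, `composite` and `gram` are hypothetical names. Each kernel is a function of a single pair of points, the composite is again such a function, and nested `jax.vmap` builds the Gram matrix inside one `jax.jit`-compiled function, which gives XLA the opportunity to fuse the elementwise operations. The gradient with respect to the length scale comes from `jax.grad` instead of hand-written derivatives.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch, not the sklearn-jax-kernels implementation.
# Base kernels are written for a single pair of points (x, y).
def rbf(x, y, length_scale):
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / length_scale ** 2)

def linear(x, y):
    return jnp.dot(x, y)

# The composite is again a pairwise function, so no separate n x n
# matrix has to be formed for each component.
def composite(x, y, length_scale):
    return rbf(x, y, length_scale) * linear(x, y) + 1.0

def gram(X, Y, length_scale):
    # Inner vmap runs over rows of Y, outer vmap over rows of X.
    over_y = jax.vmap(composite, in_axes=(None, 0, None))
    over_xy = jax.vmap(over_y, in_axes=(0, None, None))
    return over_xy(X, Y, length_scale)

gram_jit = jax.jit(gram)

X = jnp.arange(12.0).reshape(4, 3) / 10.0
K = gram_jit(X, X, 1.0)

# Gradient of a scalar summary of the Gram matrix w.r.t. the length scale.
dsum_dl = jax.grad(lambda l: gram(X, X, l).sum())(1.0)
```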

The goal of sklearn-jax-kernels is to provide the same flexibility and ease of use as scikit-learn kernels while improving speed and enabling faster design of new kernels through automatic differentiation.

The kernels in this package follow the scikit-learn kernel API.
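For reference, the scikit-learn kernel interface looks like this, shown with scikit-learn's own `RBF` kernel. We assume, but have not verified for every method, that the kernels in this package expose the same calls.

```python
import numpy as np
from sklearn.gaussian_process.kernels import RBF

X = np.array([[0.0], [1.0], [3.0]])
Y = np.array([[0.0], [2.0]])
kernel = RBF(length_scale=2.0)

K_xy = kernel(X, Y)           # cross-kernel matrix, shape (3, 2)
K_diag = kernel.diag(X)       # k(x, x) for each sample, shape (3,)
theta = kernel.theta          # log-transformed hyperparameters
params = kernel.get_params()  # hyperparameters as a dict

# A copy of the kernel with new (log-space) hyperparameters.
kernel_1 = kernel.clone_with_theta(np.log([1.0]))
```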

Quickstart

A short demonstration of how the kernels can be used, inspired by the scikit-learn documentation.

from sklearn import datasets
import jax.numpy as jnp
from sklearn_jax_kernels import RBF, GaussianProcessClassifier

iris = datasets.load_iris()
X = jnp.asarray(iris.data)
y = jnp.array(iris.target, dtype=int)

kernel = 1. + RBF(length_scale=1.0)
gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y)

Here is a further example demonstrating how kernels can be combined:

from sklearn_jax_kernels.base_kernels import RBF, NormalizedKernel
from sklearn_jax_kernels.structured.strings import SpectrumKernel

my_kernel = RBF(1.) * SpectrumKernel(n_gram_length=3)
my_kernel_2 = RBF(1.) + RBF(2.)
my_kernel_2 = NormalizedKernel(my_kernel_2)
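NormalizedKernel presumably applies the standard cosine normalization k(x, y) / sqrt(k(x, x) k(y, y)); this definition is our assumption, not stated above. A minimal JAX sketch of that normalization applied to the Gram matrix of RBF(1.) + RBF(2.), with hypothetical helper names:

```python
import jax.numpy as jnp

# Hypothetical helpers; assumes cosine normalization as the definition.
def rbf_gram(X, Y, length_scale):
    sq = jnp.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq / length_scale ** 2)

def normalize_gram(K_xy, diag_x, diag_y):
    # k(x, y) / sqrt(k(x, x) * k(y, y))
    return K_xy / jnp.sqrt(jnp.outer(diag_x, diag_y))

X = jnp.array([[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]])

# Mirrors RBF(1.) + RBF(2.) followed by normalization.
K = rbf_gram(X, X, 1.0) + rbf_gram(X, X, 2.0)
K_norm = normalize_gram(K, jnp.diag(K), jnp.diag(K))
```

Because every RBF kernel has k(x, x) = 1, the normalized sum here is simply K / 2; normalization matters more for kernels such as the spectrum kernel, whose diagonal depends on the input.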

Further inspiration can be taken from the tests in the tests subfolder.

Implemented Kernels

  • Kernel compositions (+, -, *, /, exponentiation)
  • Kernels for real-valued data:
    • RBF kernel
  • Kernels for strings of the same length:
    • SpectrumKernel
    • DistanceSpectrumKernel: a SpectrumKernel with distance weighting between matching substrings
    • ReverseComplement Spectrum kernel (relevant for applications in biology when working with DNA sequences)
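The spectrum kernel compares two strings by their k-mer (n-gram) content: k(s, t) = Σ_u φ_u(s) φ_u(t), where φ_u(s) counts the occurrences of the k-mer u in s. The pure-Python sketch below follows that textbook definition and is not the package's JAX implementation. The reverse-complement variant treats each k-mer and its reverse complement as the same feature (canonical k-mers), which is one common way to make the kernel strand-invariant; whether the package's ReverseComplement Spectrum kernel uses exactly this definition is an assumption.

```python
from collections import Counter

COMPLEMENT = str.maketrans("ACGT", "TGCA")

def reverse_complement(seq):
    return seq.translate(COMPLEMENT)[::-1]

def kmer_counts(seq, k):
    # phi(s): number of occurrences of every k-mer in seq.
    return Counter(seq[i:i + k] for i in range(len(seq) - k + 1))

def spectrum_kernel(s, t, k=3):
    cs, ct = kmer_counts(s, k), kmer_counts(t, k)
    # Missing keys in a Counter count as 0.
    return sum(n * ct[u] for u, n in cs.items())

def canonical_counts(seq, k):
    # Assumption: map each k-mer to min(kmer, reverse complement).
    return Counter(
        min(seq[i:i + k], reverse_complement(seq[i:i + k]))
        for i in range(len(seq) - k + 1)
    )

def canonical_spectrum_kernel(s, t, k=3):
    cs, ct = canonical_counts(s, k), canonical_counts(t, k)
    return sum(n * ct[u] for u, n in cs.items())

print(spectrum_kernel("ACGTAC", "ACGTTT"))  # 2 (shared: ACG, CGT)
print(canonical_spectrum_kernel("ACGTAC", reverse_complement("ACGTAC")))  # 8
```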

TODOs

  • Implement more fundamental kernels
  • Implement a JAX-compatible version of GaussianProcessRegressor
  • Optimize GaussianProcessClassifier for performance
  • Run benchmarks to demonstrate the speed benefits
  • Add a fake "split" kernel that allows applying different kernels to different parts of the input
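The planned "split" kernel does not exist yet. A hypothetical JAX sketch of the idea: apply one kernel to the first block of features and another kernel to the remaining features, then combine the results (here by multiplication; a sum would work as well). The function names and the `split_index` parameter are our own.

```python
import jax.numpy as jnp

# Hypothetical sketch of a "split" kernel; not part of the package.
def rbf(x, y, length_scale=1.0):
    return jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / length_scale ** 2)

def linear(x, y):
    return jnp.dot(x, y)

def split_kernel(x, y, split_index):
    # RBF on features [:split_index], linear kernel on the rest.
    head = rbf(x[:split_index], y[:split_index])
    tail = linear(x[split_index:], y[split_index:])
    return head * tail

x = jnp.array([0.0, 0.0, 1.0, 2.0])
value = split_kernel(x, x, split_index=2)  # 1.0 * (1 + 4) = 5.0
```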

