
A JAX-based implementation of Kolmogorov-Arnold Networks


jaxKAN

A JAX implementation of the original Kolmogorov-Arnold Networks (KANs), using the Flax and Optax frameworks for neural networks and optimization, respectively. Our adaptation is based on the original pykan implementation; however, we have also included a built-in grid extension routine, which does not merely adapt the grid to the inputs but also extends its size.
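The grid extension idea can be illustrated with a minimal sketch. It assumes uniform grids on [-1, 1], B-splines of degree k built with the Cox-de Boor recursion, and a single one-dimensional activation. The helpers `make_grid`, `bspline_basis` and `extend_grid` are hypothetical and are not jaxKAN's API. The refit follows the least-squares approach described in the original KAN paper: the coarse spline is evaluated on sample points, and the fine-grid coefficients are chosen to reproduce those values. This sketch keeps the grid range fixed and only increases the number of intervals G; it does not show the input-based grid adaptation.

```python
import jax
import jax.numpy as jnp


def make_grid(G, k, a=-1.0, b=1.0):
    """Uniform knot vector: G intervals on [a, b], padded by k knots on each side."""
    h = (b - a) / G
    return a + h * jnp.arange(-k, G + k + 1)  # shape (G + 2k + 1,)


def bspline_basis(x, grid, k):
    """Cox-de Boor recursion. x: (batch,), grid: (G + 2k + 1,) -> (batch, G + k)."""
    x = x[:, None]
    # Degree 0: indicator function of each knot interval [t_i, t_{i+1})
    basis = ((x >= grid[:-1]) & (x < grid[1:])).astype(x.dtype)
    for p in range(1, k + 1):
        # Blend neighbouring degree-(p-1) bases into degree-p bases
        left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)]) * basis[:, :-1]
        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p]) * basis[:, 1:]
        basis = left + right
    return basis


def extend_grid(coef, x, G_old, G_new, k, a=-1.0, b=1.0):
    """Refit spline coefficients on a finer grid by least squares (hypothetical helper)."""
    old_basis = bspline_basis(x, make_grid(G_old, k, a, b), k)  # (batch, G_old + k)
    new_basis = bspline_basis(x, make_grid(G_new, k, a, b), k)  # (batch, G_new + k)
    y = old_basis @ coef  # values of the coarse spline at the sample points
    new_coef, *_ = jnp.linalg.lstsq(new_basis, y)
    return new_coef


# Usage: refine a random cubic spline from 5 to 10 grid intervals
coef = jax.random.normal(jax.random.PRNGKey(0), (5 + 3,))
x = jnp.linspace(-1.0, 1.0, 200)
new_coef = extend_grid(coef, x, G_old=5, G_new=10, k=3)
```

When G_new is a multiple of G_old, the fine spline space contains the coarse spline on [-1, 1], so the refit should reproduce it up to floating-point error there. This is what makes grid extension a warm start rather than a retraining from scratch.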

Installation

jaxKAN is available as a PyPI package. To install it, simply run

pip3 install jaxkan

The default installation depends on jax[cpu]; a GPU version is also available, which installs jax[cuda12] as a dependency instead.

Why not a more efficient implementation?

Although KANs show promise for Deep Learning in general, their authors emphasized their performance in scientific computing, in tasks such as Symbolic Regression or solving PDEs. This is why we chose to preserve their original form: although it is less computationally efficient, it allows the user to apply the full regularization terms presented in the arXiv preprint, rather than the "mock" regularization terms used, for instance, in the efficient-kan implementation.
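For reference, the full regularization in the arXiv preprint is defined on the activation values themselves. The L1 norm of an activation phi is its mean absolute value over the N_p input samples, |phi|_1 = (1/N_p) * sum_s |phi(x_s)|. The layer norm |Phi|_1 sums these over all edges. The entropy term is S(Phi) = -sum_ij (|phi_ij|_1 / |Phi|_1) * log(|phi_ij|_1 / |Phi|_1). The total penalty is lambda * (mu_1 * sum_l |Phi_l|_1 + mu_2 * sum_l S(Phi_l)). Below is a minimal sketch, assuming each layer exposes its per-edge activations as an array of shape (batch, n_out, n_in). The function names, the `eps` guard and the default `lamb` value are hypothetical and are not jaxKAN's API.

```python
import jax.numpy as jnp


def layer_regularization(activations, eps=1e-8):
    """activations: (batch, n_out, n_in) values of every edge activation on every sample."""
    # |phi|_1 for each edge: mean absolute activation over the batch
    l1_edges = jnp.mean(jnp.abs(activations), axis=0)  # (n_out, n_in)
    l1_layer = jnp.sum(l1_edges)  # |Phi|_1
    # Normalized edge magnitudes and their entropy S(Phi); eps avoids log(0)
    p = l1_edges / (l1_layer + eps)
    entropy = -jnp.sum(p * jnp.log(p + eps))
    return l1_layer, entropy


def kan_regularization(layer_activations, lamb=1e-2, mu_1=1.0, mu_2=1.0):
    """lambda * (mu_1 * sum_l |Phi_l|_1 + mu_2 * sum_l S(Phi_l)) over all layers."""
    total = 0.0
    for acts in layer_activations:
        l1, ent = layer_regularization(acts)
        total = total + mu_1 * l1 + mu_2 * ent
    return lamb * total
```

Because the penalty needs every edge's activation on every sample, each layer must materialize a (batch, n_out, n_in) tensor. This is one plausible reason why the faithful form costs more than a penalty computed directly from spline coefficients.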

Why JAX?

Because speed + scientific computing. Need we say more? Moreover, although all tests were performed on CPU, switching to GPU in JAX is straightforward.
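To illustrate the CPU/GPU point, here is a minimal sketch. The same jitted function runs unchanged on whichever backend the installed jaxlib provides, and `jax.default_backend()` reports which one is in use. The function shown is only an example and is not part of jaxKAN.

```python
import jax
import jax.numpy as jnp


@jax.jit
def silu(x):
    # Compiled by XLA for the active backend; no device-specific code is needed
    return jax.nn.silu(x)


print(jax.default_backend())  # typically "cpu" with jax[cpu], "gpu" with jax[cuda12]
print(jax.devices())
x = jnp.linspace(-2.0, 2.0, 5)
y = silu(x)
```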



Download files


Source Distribution: jaxkan-0.1.4.tar.gz (16.1 kB)

Built Distribution: jaxkan-0.1.4-py3-none-any.whl (21.6 kB, Python 3)
