A JAX-based implementation of Kolmogorov-Arnold Networks

jaxKAN

jaxKAN is a JAX implementation of the original Kolmogorov-Arnold Networks (KANs), using the Flax and Optax frameworks for neural networks and optimization, respectively. Our adaptation is based on the original pykan; however, we also include a built-in grid extension routine, which not only adapts the grid to the inputs but also increases its size.
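
To make the distinction concrete, here is a minimal sketch of the two steps in plain Python. It is not jaxKAN's actual routine: the function name `adapt_grid` and the blending parameter `grid_e` are assumptions chosen for illustration, loosely following the way pykan blends a uniform grid with one built from the sorted inputs. Adaptation repositions the knots for a fixed number of intervals; extension calls the same routine with a larger number of intervals.

```python
def adapt_grid(samples, num_intervals, grid_e=0.05):
    """Return num_intervals + 1 knots fitted to the samples (illustrative only).

    grid_e = 1.0 gives a purely uniform grid on [min, max];
    grid_e = 0.0 gives a purely sample-based (quantile-like) grid.
    """
    if num_intervals < 1 or not samples:
        raise ValueError("need at least one interval and one sample")
    xs = sorted(samples)
    n = len(xs)
    lo, hi = xs[0], xs[-1]
    knots = []
    for i in range(num_intervals + 1):
        # Sample-based knot: every (n / num_intervals)-th sorted sample,
        # with the last knot clamped to the largest sample.
        adaptive = xs[min(n * i // num_intervals, n - 1)]
        # Uniform knot on [lo, hi].
        uniform = lo + (hi - lo) * i / num_intervals
        knots.append(grid_e * uniform + (1.0 - grid_e) * adaptive)
    return knots


samples = [0.0, 0.1, 0.2, 0.3, 1.0, 2.0, 3.0, 4.0]
coarse = adapt_grid(samples, num_intervals=2)  # adaptation only: 3 knots
fine = adapt_grid(samples, num_intervals=4)    # extension: 5 knots, same range
print(coarse)
print(fine)
```

A full extension step would also refit the spline coefficients on the finer grid (the KAN pre-print does this with a least-squares fit); that part is omitted from this sketch.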

Installation

jaxKAN is available on PyPI. To install it, run:

pip3 install jaxkan

The default installation requires jax[cpu], but there is also a GPU version, which installs jax[cuda12] as a dependency.
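
After installing, it can be useful to confirm which packages the current interpreter actually sees. The sketch below uses only the standard library's `importlib.metadata`; the helper names `installed_version` and `describe_install` are made up for this example and are not part of jaxKAN.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


def describe_install():
    """Map each relevant package name to its installed version (or None)."""
    return {name: installed_version(name)
            for name in ("jaxkan", "jax", "flax", "optax")}


if __name__ == "__main__":
    for name, version in describe_install().items():
        print(f"{name}: {version or 'not installed'}")
```

Once JAX itself is importable, `jax.default_backend()` reports whether computations run on the CPU or on a GPU.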

Why not more efficient?

Although KANs have broad potential in deep learning, their authors emphasized their performance in scientific computing, in tasks such as symbolic regression and solving PDEs. This is why we put emphasis on preserving their original form, even though it is less computationally efficient: it allows the user to apply the full regularization terms presented in the arXiv pre-print rather than the "mock" regularization terms used, for instance, in the efficient-kan implementation.
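
For reference, the regularization in the KAN pre-print is defined on the outputs of the activation functions rather than on their parameters: each activation's L1 norm is its mean absolute output over a batch, a layer's L1 norm sums these, and an entropy term is computed from the normalized norms. The sketch below computes both terms for a single layer in plain Python. The function name, the argument layout, and the weights `mu_1` and `mu_2` are assumptions for illustration, not jaxKAN's API.

```python
import math


def kan_layer_regularization(activations, mu_1=1.0, mu_2=1.0):
    """Regularization of one KAN layer from its activation outputs.

    activations[s][j][i] holds phi_{j,i}(x_i) for sample s, output node j
    and input node i (a hypothetical layout chosen for this sketch).
    """
    n_samples = len(activations)
    n_out = len(activations[0])
    n_in = len(activations[0][0])

    # L1 norm of each activation: mean absolute output over the batch.
    l1 = [[sum(abs(activations[s][j][i]) for s in range(n_samples)) / n_samples
           for i in range(n_in)]
          for j in range(n_out)]

    # L1 norm of the layer: sum over all of its activations.
    layer_l1 = sum(sum(row) for row in l1)

    # Entropy of the normalized norms; zero-norm activations contribute 0.
    entropy = 0.0
    if layer_l1 > 0:
        for row in l1:
            for value in row:
                p = value / layer_l1
                if p > 0:
                    entropy -= p * math.log(p)

    return mu_1 * layer_l1 + mu_2 * entropy


# Two samples, one output node, two input nodes.
batch = [[[1.0, -1.0]], [[1.0, 1.0]]]
print(kan_layer_regularization(batch))
```

Both terms need the individual activation outputs, so they can only be evaluated by an implementation that keeps those outputs available, which is the trade-off described above.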

Citation

If you used jaxKAN in your own academic work, please consider citing the paper that introduces the framework:

@misc{rigas2024adaptivetraininggriddependentphysicsinformed,
      title={Adaptive Training of Grid-Dependent Physics-Informed Kolmogorov-Arnold Networks}, 
      author={Spyros Rigas and Michalis Papachristou and Theofilos Papadopoulos and Fotios Anagnostopoulos and Georgios Alexandridis},
      year={2024},
      eprint={2407.17611},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2407.17611}, 
}


