A JAX-based implementation of Kolmogorov-Arnold Networks
jaxKAN
A JAX implementation of the original Kolmogorov-Arnold Networks (KANs), using the Flax and Optax frameworks for neural networks and optimization, respectively. Our adaptation is based on the original pykan; however, we also included a built-in grid extension routine, which does not simply adapt the grid to the inputs but also extends its size.
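To make the idea of grid extension concrete, here is a minimal sketch in plain JAX. It assumes uniform grids, spline degree k (k = 3 gives cubic splines) and a single one-dimensional activation. It evaluates B-spline bases with the Cox-de Boor recursion, then refits the coefficients by least squares onto a finer grid spanning the observed inputs, which is the grid extension procedure described in the original KAN paper. The helper names (make_grid, bspline_basis, extend_grid) are hypothetical and not part of the jaxKAN API, and jaxKAN's own routine may differ in its details.

```python
import jax
import jax.numpy as jnp

def make_grid(x_min, x_max, G, k):
    """Uniform knot vector with G intervals on [x_min, x_max], padded by k knots per side.
    Hypothetical helper; returns G + 2k + 1 knots."""
    h = (x_max - x_min) / G
    return x_min + h * jnp.arange(-k, G + k + 1)

def bspline_basis(x, grid, k):
    """Evaluate the G + k B-spline basis functions of degree k at points x (Cox-de Boor).
    x: (N,), grid: (G + 2k + 1,). Returns an (N, G + k) array."""
    x = x[:, None]
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    B = ((x >= grid[None, :-1]) & (x < grid[None, 1:])).astype(x.dtype)
    for p in range(1, k + 1):
        # Blend neighbouring lower-degree bases with linear weights.
        left = (x - grid[None, :-(p + 1)]) / (grid[None, p:-1] - grid[None, :-(p + 1)]) * B[:, :-1]
        right = (grid[None, p + 1:] - x) / (grid[None, p + 1:] - grid[None, 1:-p]) * B[:, 1:]
        B = left + right
    return B

def extend_grid(x, old_grid, coef, G_new, k):
    """Refit a spline onto a larger grid adapted to the inputs x (hypothetical helper)."""
    # Current activation values on the sample inputs.
    y = bspline_basis(x, old_grid, k) @ coef
    # New uniform grid spanning the observed input range, with G_new intervals.
    new_grid = make_grid(x.min(), x.max(), G_new, k)
    A = bspline_basis(x, new_grid, k)
    # Least-squares fit so the refined spline reproduces the old activation.
    new_coef, *_ = jnp.linalg.lstsq(A, y)
    return new_grid, new_coef

# Usage: refine a cubic spline from 5 to 10 grid intervals.
x = jnp.linspace(-1.0, 1.0, 200)
grid = make_grid(-1.0, 1.0, 5, 3)
coef = jax.random.normal(jax.random.PRNGKey(0), (8,))
new_grid, new_coef = extend_grid(x, grid, coef, G_new=10, k=3)
```

In this usage example the new grid spans the same range as the old one and G_new is a multiple of the old G, so every interior coarse knot is also a fine knot and the refit reproduces the original activation up to floating-point error, while providing more trainable coefficients.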
Installation
jaxKAN is available as a PyPI package. For installation, simply run

pip3 install jaxkan

The default installation requires jax[cpu], but there is also a gpu version, which installs jax[cuda12] as a dependency.
Why not more efficient?
Although KANs show potential across Deep Learning in general, their authors emphasized their performance in scientific computing, in tasks such as Symbolic Regression or solving PDEs. This is why we preserve their original formulation, even though it is less computationally efficient: it allows the user to apply the full regularization terms presented in the arXiv pre-print, rather than the "mock" regularization terms used, for instance, in the efficient-kan implementation.
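For reference, the regularization terms defined in the KAN paper are the L1 norm of each activation φ_ij, averaged over a batch of inputs, and the entropy of these norms across a layer, combined as l_pred + λ(μ1 Σ_l |Φ_l|_1 + μ2 Σ_l S(Φ_l)). Both terms need the per-edge outputs φ_ij(x) of every layer, which efficient implementations avoid materializing; as we understand it, efficient-kan instead regularizes the spline weights directly. The minimal sketch below assumes each layer exposes its per-edge outputs as an array of shape (N, n_in, n_out); the function names and the default values of lam, mu1 and mu2 are illustrative assumptions, not jaxKAN's API.

```python
import jax.numpy as jnp

def layer_regularization(phi, eps=1e-8):
    """phi: (N, n_in, n_out) per-edge activation outputs phi_ij(x^(s)) on N samples.
    Returns the layer L1 norm |Phi|_1 and the entropy S(Phi)."""
    l1_edges = jnp.mean(jnp.abs(phi), axis=0)   # |phi_ij|_1, shape (n_in, n_out)
    l1_layer = jnp.sum(l1_edges)                # |Phi|_1, summed over all edges
    p = l1_edges / (l1_layer + eps)             # normalized edge importances
    entropy = -jnp.sum(p * jnp.log(p + eps))    # S(Phi); eps avoids log(0)
    return l1_layer, entropy

def total_loss(pred_loss, phis, lam=1e-3, mu1=1.0, mu2=1.0):
    """Prediction loss plus the sparsity regularization over all layers (hypothetical helper)."""
    reg = 0.0
    for phi in phis:
        l1, ent = layer_regularization(phi)
        reg = reg + mu1 * l1 + mu2 * ent
    return pred_loss + lam * reg
```

The entropy term is what pushes training towards a few dominant edges per layer, which is useful for pruning and symbolic regression; a penalty on weights alone cannot express it, because it depends on how the activations behave on the data.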
Citation
If you use jaxKAN in your academic work, please consider citing the paper that introduced the framework:
@misc{rigas2024adaptivetraininggriddependentphysicsinformed,
title={Adaptive Training of Grid-Dependent Physics-Informed Kolmogorov-Arnold Networks},
author={Spyros Rigas and Michalis Papachristou and Theofilos Papadopoulos and Fotios Anagnostopoulos and Georgios Alexandridis},
year={2024},
eprint={2407.17611},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.17611},
}