jax_toolkit
A collection of JAX functions to help with common machine learning and deep learning functionality.
This library currently provides the basics of a number of losses and metrics. We intend to add more complexity and functionality as it is needed. Contributions, pull requests, and bug reports are very welcome if you find a problem or need something that is currently missing.
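The specific losses and metrics are not listed here. As a rough illustration of what this kind of function looks like in JAX, the following is a minimal sketch of a binary log loss and an accuracy metric written directly with `jax.numpy`. The names `binary_log_loss` and `binary_accuracy` are hypothetical and are not part of jax_toolkit's API.

```python
import jax
import jax.numpy as jnp


def binary_log_loss(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy (log loss) for probabilities in [0, 1]."""
    # Clip probabilities away from 0 and 1 so both logs stay finite.
    y_prob = jnp.clip(y_prob, eps, 1.0 - eps)
    losses = -(y_true * jnp.log(y_prob) + (1.0 - y_true) * jnp.log(1.0 - y_prob))
    return jnp.mean(losses)


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of thresholded predictions that match y_true."""
    y_pred = (y_prob >= threshold).astype(y_true.dtype)
    return jnp.mean(y_pred == y_true)


# Both functions are pure, so they can be jit-compiled; the loss is also
# differentiable with respect to the predicted probabilities (argument 1).
log_loss_jit = jax.jit(binary_log_loss)
accuracy_jit = jax.jit(binary_accuracy)
grad_fn = jax.grad(binary_log_loss, argnums=1)

y_true = jnp.array([1.0, 0.0, 1.0, 0.0])
y_prob = jnp.array([0.9, 0.2, 0.6, 0.7])

print(float(log_loss_jit(y_true, y_prob)))  # ≈ 0.5108
print(float(accuracy_jit(y_true, y_prob)))  # 0.75
```

Writing losses and metrics as pure functions of arrays is what lets them compose with `jax.jit` and `jax.grad` as shown.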
Installation
pip install jax_toolkit
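Or, to also install the optional loss-function utilities (quote the argument in shells such as zsh, which treat square brackets specially):
pip install "jax_toolkit[losses_utils]"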