jax_toolkit
A collection of JAX functions for common machine learning and deep learning tasks.
This library currently provides basic implementations of a number of losses and metrics. We intend to add more functionality as and when it is needed. Contributions, pull requests, and bug reports are very welcome if you discover a problem or need something that is currently missing.
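To illustrate the kind of loss and metric functions this covers, here is a minimal sketch in plain JAX. It does not use or reproduce the jax_toolkit API: the names `mean_squared_error` and `binary_accuracy`, their signatures, and the sample data are assumptions made for this example only.

```python
import jax
import jax.numpy as jnp


def mean_squared_error(y_true, y_pred):
    """Hypothetical loss: mean of the squared differences."""
    return jnp.mean((y_pred - y_true) ** 2)


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Hypothetical metric: fraction of thresholded predictions matching 0/1 labels."""
    y_hat = (y_prob >= threshold).astype(y_true.dtype)
    return jnp.mean(y_hat == y_true)


# A loss written with jax.numpy is differentiable, so jax.grad applies directly.
# argnums=1 differentiates with respect to the predictions.
mse_grad = jax.jit(jax.grad(mean_squared_error, argnums=1))

# Made-up sample data for illustration.
y_true = jnp.array([1.0, 0.0, 1.0, 1.0])
y_pred = jnp.array([0.9, 0.2, 0.4, 0.8])

loss = mean_squared_error(y_true, y_pred)
acc = binary_accuracy(y_true, y_pred)
grad = mse_grad(y_true, y_pred)
```

Keeping the loss as a pure function of arrays is what lets it compose with `jax.grad` and `jax.jit`. The metric thresholds its input, so it is meant for reporting rather than for differentiation.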
Installation
pip install jax_toolkit
Or, for the additional loss function utilities (the quotes stop shells such as zsh from interpreting the square brackets):
pip install "jax_toolkit[losses_utils]"
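After installing, one way to confirm that the package is visible to the current Python environment is to query its distribution metadata. This is a sketch using only the standard library; `installed_version` is a hypothetical helper written for this example, not part of jax_toolkit.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# "jax_toolkit" is the distribution name used in the pip commands above.
found = installed_version("jax_toolkit")
print("jax_toolkit:", found if found is not None else "not installed")
```

Looking up the distribution metadata avoids importing the package itself, so the check works even if an optional dependency is missing.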