
A collection of jax functions to help with common machine/deep learning related functionality.

Project description

jax_toolkit


Documentation, PyPI

This library currently contains basic implementations of a number of common losses and metrics. We intend to add more complexity and functionality as and when it's needed. Contributions, pull requests, and bug reports are very welcome if you discover problems or need something that is currently missing.
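The specific function names and signatures that jax_toolkit exposes are not listed here, so the snippet below is not the library's API. It is a minimal sketch, assuming only `jax` and `jax.numpy`, of the kind of loss and metric such a collection typically provides: a binary log loss and a binary accuracy metric written as pure, jit-compiled functions. The names `binary_log_loss` and `binary_accuracy` are hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical example, not part of jax_toolkit's documented API.
@jax.jit
def binary_log_loss(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; y_pred holds probabilities in (0, 1)."""
    # Clip so that log() never sees exactly 0 or 1.
    p = jnp.clip(y_pred, eps, 1.0 - eps)
    return -jnp.mean(y_true * jnp.log(p) + (1.0 - y_true) * jnp.log1p(-p))


# Hypothetical example, not part of jax_toolkit's documented API.
@jax.jit
def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions that land on the correct side of `threshold`."""
    return jnp.mean((y_pred >= threshold) == (y_true == 1))


y_true = jnp.array([1.0, 0.0, 1.0, 0.0])
y_pred = jnp.array([0.9, 0.2, 0.6, 0.4])

loss = binary_log_loss(y_true, y_pred)
acc = binary_accuracy(y_true, y_pred)
# Gradient of the loss with respect to the predictions (argument index 1).
grads = jax.grad(binary_log_loss, argnums=1)(y_true, y_pred)
print(loss, acc, grads)
```

Because both functions are plain `jax.numpy` code with no side effects, they can be compiled with `jax.jit` and differentiated with `jax.grad`, which is the main reason to write losses and metrics this way in JAX.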

Installation

pip install jax_toolkit

Or, to also install the optional dependencies for the additional loss function utilities:

pip install "jax_toolkit[losses_utils]"

Download files


Source Distribution

jax_toolkit-0.2.0.tar.gz (13.3 kB)

Built Distribution

jax_toolkit-0.2.0-py3-none-any.whl (18.6 kB, Python 3)
