
A lightweight machine learning package for computational mechanics.

Project description


A lightweight machine learning package for computational mechanics built on JAX.


Check out the Documentation for examples and reference material.

What is Klax?

Klax provides specialized machine learning architectures, constraints, and training utilities for mechanics and physics applications. Built on top of JAX, Equinox, and Optax, it offers:

  • Special Neural Networks: Implementations of, e.g., Input Convex Neural Networks (ICNNs), matrix-valued neural networks, MLPs with custom initialization, and more.
  • JAX Compatibility: Seamless integration with JAX's automatic differentiation and acceleration.
  • Parameter Constraints: Differentiable and non-differentiable parameter constraints through klax.Unwrappable and klax.Constraint.
  • Customizable Training: Methods and APIs for customized calibrations on arbitrary PyTree data structures through klax.fit, klax.Loss, and klax.Callback.

Klax is designed to be minimally intrusive: all models inherit directly from equinox.Module without additional abstraction layers, which ensures full compatibility with the JAX/Equinox ecosystem.

The constraint system is derived from Paramax's paramax.AbstractUnwrappable, extending it to support non-differentiable/zero-gradient parameter constraints such as ReLU-based non-negativity constraints.

The provided calibration utilities (klax.fit, klax.Loss, klax.Callback) are designed to operate on arbitrarily shaped PyTrees of data, fully utilizing the flexibility of JAX and Equinox. While they cover most common machine learning use cases, as well as our specialized requirements, they remain entirely optional. The core building blocks of Klax work seamlessly in custom training loops.

Currently, Klax's training utilities are built around Optax, but other optimization libraries could be supported in the future if there is demand.

If you like using Klax, feel free to leave a GitHub star, and if there is a machine learning architecture that you think should be included in Klax, please consider opening a pull request.

Installation

Klax requires Python 3.12+.

pip install klax

or get the most recent changes from the main branch via

pip install "klax @ git+https://github.com/Drenderer/klax.git@main"
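After installing, the two requirements above (Python 3.12+ and an installed klax distribution) can be checked with the standard library alone. The function name `check_klax_environment` is hypothetical; it only reads installed package metadata and does not import Klax itself.

```python
from __future__ import annotations

import sys
from importlib.metadata import PackageNotFoundError, version


def check_klax_environment() -> tuple[bool, str | None]:
    """Return (python_ok, klax_version); klax_version is None if not installed."""
    python_ok = sys.version_info >= (3, 12)
    try:
        klax_version = version("klax")  # reads distribution metadata only
    except PackageNotFoundError:
        klax_version = None
    return python_ok, klax_version


if __name__ == "__main__":
    ok, ver = check_klax_environment()
    print(f"Python 3.12+: {ok}; installed klax: {ver or 'not found'}")
```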



Download files


Source Distribution

klax-0.1.3.tar.gz (3.2 MB)

Uploaded Source

Built Distribution


klax-0.1.3-py3-none-any.whl (43.6 kB)

Uploaded Python 3

File details

Details for the file klax-0.1.3.tar.gz.

File metadata

  • Download URL: klax-0.1.3.tar.gz
  • Upload date:
  • Size: 3.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.7.19

File hashes

Hashes for klax-0.1.3.tar.gz

Algorithm      Hash digest
SHA256         fb6439e3e0a97d037f33632ff7a266329bb5f4f298d83dcd3376f9b024c8c1d8
MD5            249e2491358ef08e6b7e28b179c15f1d
BLAKE2b-256    84a7e3183aec069f2935afef4108419e1f9d2fd8cc013b81c7763a0a9509723c


File details

Details for the file klax-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: klax-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 43.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.7.19

File hashes

Hashes for klax-0.1.3-py3-none-any.whl

Algorithm      Hash digest
SHA256         e464a0eaf981013218961fd3cd0313b68b8493597345f4596f20383fad38ce29
MD5            af8b207e6ec7007765a9b77573df16be
BLAKE2b-256    1507e4d52c33778e610b2fb1fee3f4bb5d69a815a5fc178ae5cba58d76d60d15
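The published SHA256 digests can be checked locally before installing a downloaded file. This is a standard-library sketch; `sha256_of`, `verify`, and the `EXPECTED_SHA256` mapping are illustrative names, and the digests are copied from the hash tables on this page.

```python
from __future__ import annotations

import hashlib
from pathlib import Path

# SHA256 digests as published for klax 0.1.3.
EXPECTED_SHA256 = {
    "klax-0.1.3.tar.gz": "fb6439e3e0a97d037f33632ff7a266329bb5f4f298d83dcd3376f9b024c8c1d8",
    "klax-0.1.3-py3-none-any.whl": "e464a0eaf981013218961fd3cd0313b68b8493597345f4596f20383fad38ce29",
}


def sha256_of(path: str | Path, chunk_size: int = 1 << 20) -> str:
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str | Path) -> bool:
    """Return True only if the file's name is known and its SHA256 matches."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

pip can also enforce hashes itself through its hash-checking mode, by listing `klax==0.1.3 --hash=sha256:<digest>` in a requirements file and installing with `pip install --require-hashes -r requirements.txt`.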

