JAX Metrics

A Metrics library for the JAX ecosystem

Main Features

  • Standard framework-independent metrics that can be used in any JAX project.
  • Pytree-based abstractions that can natively integrate with all JAX APIs.
  • Distributed-friendly APIs that make it easy to synchronize metrics across devices.
  • Automatic accumulation over entire epochs.

JAX Metrics is implemented on top of Treeo.

What is included?

  • A Keras-like Metric abstraction.
  • A Keras-like Loss abstraction.
  • Metrics, Losses, and LossesAndMetrics combinators.
  • A metrics module containing popular metrics.
  • Losses and regularizers modules containing popular losses and regularizers.

Installation

Install using pip:

pip install jax_metrics

Getting Started

import jax_metrics as jm

metric = jm.metrics.Accuracy()

# Initialize the metric
metric = metric.reset()

# Update the metric with a batch of predictions and labels
metric = metric.update(target=y, preds=logits)

# Get the current value of the metric
acc = metric.compute() # 0.95

# alternatively, produce a logs dict
logs = metric.compute_logs() # {'accuracy': 0.95}
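
Conceptually, an accuracy metric only needs to accumulate two counts across batches: `reset` zeroes them, `update` adds the counts from one batch, and `compute` divides them. The following is a minimal pure-Python sketch of that cycle, not JAX Metrics' actual implementation. `SketchAccuracy` is a hypothetical class, and treating `preds` as per-class scores reduced with an argmax is an assumption made for illustration.

```python
from dataclasses import dataclass, replace
from typing import Dict, Sequence


@dataclass(frozen=True)
class SketchAccuracy:
    """Hypothetical accuracy metric mirroring the reset/update/compute cycle."""

    correct: int = 0  # correct predictions seen so far
    total: int = 0  # examples seen so far

    def reset(self) -> "SketchAccuracy":
        # Return a fresh state instead of mutating self (functional style).
        return replace(self, correct=0, total=0)

    def update(
        self, *, target: Sequence[int], preds: Sequence[Sequence[float]]
    ) -> "SketchAccuracy":
        # Assumption: each row of preds holds per-class scores; argmax picks the class.
        predicted = [max(range(len(row)), key=row.__getitem__) for row in preds]
        hits = sum(int(p == t) for p, t in zip(predicted, target))
        return replace(self, correct=self.correct + hits, total=self.total + len(target))

    def compute(self) -> float:
        return self.correct / self.total

    def compute_logs(self) -> Dict[str, float]:
        return {"accuracy": self.compute()}


metric = SketchAccuracy().reset()
metric = metric.update(target=[0, 1], preds=[[2.0, 0.1], [0.3, 1.5]])  # 2 of 2 correct
metric = metric.update(target=[1, 1], preds=[[0.9, 0.2], [0.0, 4.0]])  # 1 of 2 correct
print(metric.compute())  # 0.75
print(metric.compute_logs())  # {'accuracy': 0.75}
```

Because `update` in this sketch returns a new object, its result must be reassigned, matching the `metric = metric.update(...)` pattern used throughout this README.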

import jax
import jax_metrics as jm

metric = jm.metrics.Accuracy()

@jax.jit
def init_step(metric: jm.Metric) -> jm.Metric:
    return metric.reset()


def loss_fn(params, metric, x, y):
    ...
    metric = metric.update(target=y, preds=logits)
    ...

    return loss, metric

@jax.jit
def train_step(params, metric, x, y):
    # the updated metric comes back as the auxiliary output of jax.grad
    grads, metric = jax.grad(loss_fn, has_aux=True)(
        params, metric, x, y
    )
    ...
    return params, metric

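
The training-step pattern above works because the metric is an immutable value that is passed into `loss_fn` and handed back as the auxiliary output, so the caller always holds the latest state. The sketch below shows that threading and epoch-level accumulation in plain Python, without JAX. `SketchMean`, the one-weight model, and this `loss_fn` are hypothetical stand-ins; in JAX Metrics the metric is additionally a pytree, which is what lets it cross `jax.jit` and `jax.grad(..., has_aux=True)`.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class SketchMean:
    """Hypothetical immutable metric: running mean of per-example values."""

    total: float = 0.0
    num: int = 0

    def reset(self) -> "SketchMean":
        return SketchMean()

    def update(self, *, values: List[float]) -> "SketchMean":
        # Returns a new state; the caller has to keep the returned value.
        return SketchMean(self.total + sum(values), self.num + len(values))

    def compute(self) -> float:
        return self.total / self.num


def loss_fn(
    w: float, metric: SketchMean, x: List[float], y: List[float]
) -> Tuple[float, SketchMean]:
    # Hypothetical squared-error loss for the model y_hat = w * x.
    errors = [(w * xi - yi) ** 2 for xi, yi in zip(x, y)]
    metric = metric.update(values=errors)  # record per-example errors
    return sum(errors) / len(errors), metric  # (loss, auxiliary metric state)


metric = SketchMean().reset()  # reset once at the start of the epoch
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [5.0])]
for x, y in batches:
    loss, metric = loss_fn(1.5, metric, x, y)  # thread the metric through each step

print(loss)  # 0.25 (last batch only)
print(metric.compute())  # 0.5 (mean over all 3 examples of the epoch)
```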
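def loss_fn(params, metric, x, y):
    ...
    # compute the metric state for this batch only
    batch_updates = metric.batch_updates(target=y, preds=logits)
    # gather over all devices and aggregate
    batch_updates = jax.lax.all_gather(batch_updates, "device").aggregate()
    # merge the batch state into the accumulated metric state
    metric = metric.merge(batch_updates)
    ...
    return loss, metric

# alternatively, if the metric state can be summed across devices,
# psum replaces the all_gather + aggregate step:
batch_updates = jax.lax.psum(batch_updates, "device")
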
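
The distributed recipe above can be reasoned about as three steps on each device: build a state from the local batch only, combine the per-device states, then merge the result into the running state. The sketch simulates this in plain Python, with a list standing in for the device axis; `AccState`, `batch_updates`, and `aggregate` here are hypothetical helpers, not the library's code. It also illustrates the assumption behind the `psum` alternative: when every field of the state is an additive count, summing the fields across devices gives the same result as gathering and aggregating, whereas a non-additive state (for example a running maximum) would not.

```python
from dataclasses import dataclass
from functools import reduce
from typing import List, Sequence


@dataclass(frozen=True)
class AccState:
    """Hypothetical accuracy state made only of additive counts."""

    correct: int
    total: int

    def merge(self, other: "AccState") -> "AccState":
        # Combining two states just adds their counts.
        return AccState(self.correct + other.correct, self.total + other.total)

    def compute(self) -> float:
        return self.correct / self.total


def batch_updates(target: Sequence[int], preds: Sequence[int]) -> AccState:
    # State describing one local batch only (preds are already class ids here).
    return AccState(sum(int(p == t) for p, t in zip(preds, target)), len(target))


def aggregate(gathered: List[AccState]) -> AccState:
    # Stand-in for all_gather(...).aggregate(): reduce over the device axis.
    return reduce(AccState.merge, gathered)


metric = AccState(correct=6, total=8)  # running state from earlier steps

# Each simulated device sees a different shard of the global batch.
shards = [([0, 1], [0, 0]), ([1, 1], [1, 1])]
per_device = [batch_updates(target, preds) for target, preds in shards]

gathered = aggregate(per_device)
# psum-style alternative: sum each field across devices.
summed = AccState(
    sum(s.correct for s in per_device), sum(s.total for s in per_device)
)
assert summed == gathered  # equal because every field is additive

metric = metric.merge(gathered)
print(metric.compute())  # 0.75 (9 correct out of 12)
```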
metrics = jm.Metrics([
    jm.metrics.Accuracy(),
    jm.metrics.F1(), # not yet implemented 😅, coming soon?
])

# same API
metrics = metrics.reset()
# same API
metrics = metrics.update(target=y, preds=logits)
# compute now returns a dict
metrics.compute() # {'accuracy': 0.95, 'f1': 0.87}
# same as compute_logs in this case
metrics.compute_logs() # {'accuracy': 0.95, 'f1': 0.87}

metrics = jm.Metrics({
    "acc": jm.metrics.Accuracy(),
    "f_one": jm.metrics.F1(), # not yet implemented 😅, coming soon?
})

# same API
metrics = metrics.reset()
# same API
metrics = metrics.update(target=y, preds=logits)
# compute now returns a dict
metrics.compute() # {'acc': 0.95, 'f_one': 0.87}
# same as compute_logs in this case
metrics.compute_logs() # {'acc': 0.95, 'f_one': 0.87}

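
Both forms of `jm.Metrics` shown above apply the same `update` call to every child metric and collect the results into one dict. In the list form the example keys (`'accuracy'`, `'f1'`) match the lowercased class names, while the dict form uses the keys you supply. The sketch below reproduces that behaviour in plain Python under those assumptions; `SketchMetrics`, `Hits`, and `Seen` are hypothetical, and `reset`/`compute_logs` are omitted for brevity.

```python
from collections.abc import Mapping
from typing import Dict, Sequence, Union


class Hits:
    """Hypothetical metric: fraction of exact matches between target and preds."""

    def __init__(self, correct: int = 0, total: int = 0):
        self.correct, self.total = correct, total

    def update(self, *, target: Sequence[int], preds: Sequence[int]) -> "Hits":
        hits = sum(int(t == p) for t, p in zip(target, preds))
        return Hits(self.correct + hits, self.total + len(target))

    def compute(self) -> float:
        return self.correct / self.total


class Seen:
    """Hypothetical metric: number of examples seen so far."""

    def __init__(self, n: int = 0):
        self.n = n

    def update(self, *, target: Sequence[int], preds: Sequence[int]) -> "Seen":
        return Seen(self.n + len(target))

    def compute(self) -> int:
        return self.n


class SketchMetrics:
    """Hypothetical combinator that forwards one update to every child metric."""

    def __init__(self, metrics: Union[Sequence, Mapping]):
        if isinstance(metrics, Mapping):
            self.metrics = dict(metrics)  # user-supplied keys
        else:
            # Assumption: the default key is the lowercased class name.
            self.metrics = {type(m).__name__.lower(): m for m in metrics}

    def update(self, **kwargs) -> "SketchMetrics":
        # Pass a dict so the existing keys are preserved.
        return SketchMetrics({k: m.update(**kwargs) for k, m in self.metrics.items()})

    def compute(self) -> Dict[str, float]:
        return {k: m.compute() for k, m in self.metrics.items()}


target, preds = [1, 0, 1, 1], [1, 1, 1, 0]
as_list = SketchMetrics([Hits(), Seen()]).update(target=target, preds=preds)
as_dict = SketchMetrics({"acc": Hits(), "n": Seen()}).update(target=target, preds=preds)
print(as_list.compute())  # {'hits': 0.5, 'seen': 4}
print(as_dict.compute())  # {'acc': 0.5, 'n': 4}
```

In this sketch, two metrics of the same class passed in list form would collide on one key; the dict form avoids that by letting you choose the names.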
losses = jm.Losses([
    jm.losses.Crossentropy(),
    jm.regularizers.L2(1e-4),
])

# same API
losses = losses.reset()
# same API
losses = losses.update(target=y, preds=logits, parameters=params)
# compute now returns a dict
losses.compute() # {'crossentropy': 0.23, 'l2': 0.005}
# same as compute_logs in this case
losses.compute_logs() # {'crossentropy': 0.23, 'l2': 0.005}
# you can also compute the total loss
total_loss = losses.total_loss() # 0.235

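
In the example above, `total_loss` (0.235) equals the sum of the individual entries returned by `compute` (0.23 + 0.005). The sketch below models that relationship, assuming each loss term is averaged over the batches seen and `total_loss` sums the averaged terms. `SketchLosses`, `squared_error`, and the `l2` helper with its `weight` argument are hypothetical, not `jm.losses` or `jm.regularizers`.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Sequence


@dataclass(frozen=True)
class RunningMean:
    """Hypothetical per-term state: running sum of batch values and batch count."""

    total: float = 0.0
    batches: int = 0

    def add(self, value: float) -> "RunningMean":
        return RunningMean(self.total + value, self.batches + 1)

    def mean(self) -> float:
        return self.total / self.batches


def squared_error(target: Sequence[float], preds: Sequence[float]) -> float:
    return sum((t - p) ** 2 for t, p in zip(target, preds)) / len(target)


def l2(parameters: Sequence[float], weight: float) -> float:
    # Hypothetical L2 regularizer: weight * sum of squared parameters.
    return weight * sum(p * p for p in parameters)


class SketchLosses:
    def __init__(self, states: Optional[Dict[str, RunningMean]] = None):
        self.states = states or {"squared_error": RunningMean(), "l2": RunningMean()}

    def update(self, *, target, preds, parameters) -> "SketchLosses":
        values = {
            "squared_error": squared_error(target, preds),
            "l2": l2(parameters, weight=0.5),
        }
        return SketchLosses({k: self.states[k].add(v) for k, v in values.items()})

    def compute(self) -> Dict[str, float]:
        return {k: s.mean() for k, s in self.states.items()}

    def total_loss(self) -> float:
        # The total is the sum of the individual averaged terms.
        return sum(self.compute().values())


params = [1.0, 1.0]
losses = SketchLosses()
losses = losses.update(target=[1.0, 2.0], preds=[1.0, 0.0], parameters=params)
losses = losses.update(target=[3.0], preds=[2.0], parameters=params)
print(losses.compute())  # {'squared_error': 1.5, 'l2': 1.0}
print(losses.total_loss())  # 2.5
```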
losses = jm.Losses({
    "xent": jm.losses.Crossentropy(),
    "l_two": jm.regularizers.L2(1e-4),
})

# same API
losses = losses.reset()
# same API
losses = losses.update(target=y, preds=logits, parameters=params)
# compute now returns a dict
losses.compute() # {'xent': 0.23, 'l_two': 0.005}
# same as compute_logs in this case
losses.compute_logs() # {'xent': 0.23, 'l_two': 0.005}
# you can also compute the total loss
total_loss = losses.total_loss() # 0.235

def loss_fn(params, losses, x, y):
    ...
    # losses state for this batch only
    batch_updates = losses.batch_updates(target=y, preds=logits, parameters=params)
    loss = batch_updates.total_loss()
    # fold the batch into the running losses
    losses = losses.merge(batch_updates)
    ...
    return loss, losses

def loss_fn(params, losses, x, y):
    ...
    # batch_updates, total_loss and merge in a single call
    loss, losses = losses.loss_and_update(target=y, preds=logits, parameters=params)
    ...
    return loss, losses

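
The two `loss_fn` variants above appear to be two spellings of the same step: the first spells out `batch_updates`, `total_loss`, and `merge`, and the second folds them into `loss_and_update`. The sketch below checks that equivalence on a toy class; `SketchLosses`, its `mae`/`mse` terms, and the per-batch averaging are hypothetical. In both variants the loss that gets returned for differentiation comes from the batch-only state, while the returned object carries the running statistics.

```python
from typing import Dict, Optional, Sequence, Tuple


class SketchLosses:
    """Hypothetical Losses state: per-term running sums plus a batch count."""

    def __init__(self, sums: Optional[Dict[str, float]] = None, batches: int = 0):
        self.sums = dict(sums or {})
        self.batches = batches

    def batch_updates(
        self, *, target: Sequence[float], preds: Sequence[float]
    ) -> "SketchLosses":
        # State describing the current batch only.
        diffs = [t - p for t, p in zip(target, preds)]
        mae = sum(abs(d) for d in diffs) / len(diffs)
        mse = sum(d * d for d in diffs) / len(diffs)
        return SketchLosses({"mae": mae, "mse": mse}, batches=1)

    def merge(self, other: "SketchLosses") -> "SketchLosses":
        keys = self.sums.keys() | other.sums.keys()
        sums = {k: self.sums.get(k, 0.0) + other.sums.get(k, 0.0) for k in keys}
        return SketchLosses(sums, self.batches + other.batches)

    def compute(self) -> Dict[str, float]:
        return {k: v / self.batches for k, v in self.sums.items()}

    def total_loss(self) -> float:
        return sum(self.compute().values())

    def loss_and_update(self, *, target, preds) -> Tuple[float, "SketchLosses"]:
        # Shorthand for batch_updates -> total_loss -> merge.
        updates = self.batch_updates(target=target, preds=preds)
        return updates.total_loss(), self.merge(updates)


losses = SketchLosses()
target, preds = [1.0, 3.0], [2.0, 1.0]

# Variant 1: explicit steps.
updates = losses.batch_updates(target=target, preds=preds)
loss_a, losses_a = updates.total_loss(), losses.merge(updates)

# Variant 2: the shorthand.
loss_b, losses_b = losses.loss_and_update(target=target, preds=preds)

print(loss_a, loss_b)  # 4.0 4.0
```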
lms = jm.LossesAndMetrics(
    metrics=[
        jm.metrics.Accuracy(),
        jm.metrics.F1(), # not yet implemented 😅, coming soon?
    ],
    losses=[
        jm.losses.Crossentropy(),
        jm.regularizers.L2(1e-4),
    ],
)

# same API
lms = lms.reset()
# same API
lms = lms.update(target=y, preds=logits, parameters=params)
# compute now returns a dict
lms.compute() # {'accuracy': 0.95, 'f1': 0.87, 'crossentropy': 0.23, 'l2': 0.005}
# same as compute_logs in this case
lms.compute_logs() # {'accuracy': 0.95, 'f1': 0.87, 'crossentropy': 0.23, 'l2': 0.005}
# you can also compute the total loss
total_loss = lms.total_loss() # 0.235

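
In the `LossesAndMetrics` example, `compute` reports metrics and losses in one dict, while `total_loss` (0.235) adds up only the loss entries (0.23 + 0.005) and ignores `accuracy` and `f1`. The function below sketches that split for a single batch in plain Python. `compute_all`, `accuracy`, `mse`, and `mae` are hypothetical helpers, and accumulation across batches is left out to keep the focus on how the two groups are combined.

```python
from typing import Callable, Dict, Sequence, Tuple

Fn = Callable[[Sequence[float], Sequence[float]], float]


def accuracy(target: Sequence[float], preds: Sequence[float]) -> float:
    # Hypothetical binary accuracy: preds are probabilities thresholded at 0.5.
    hits = sum(int((p >= 0.5) == (t == 1.0)) for t, p in zip(target, preds))
    return hits / len(target)


def mse(target: Sequence[float], preds: Sequence[float]) -> float:
    return sum((t - p) ** 2 for t, p in zip(target, preds)) / len(target)


def mae(target: Sequence[float], preds: Sequence[float]) -> float:
    return sum(abs(t - p) for t, p in zip(target, preds)) / len(target)


def compute_all(
    metrics: Dict[str, Fn], losses: Dict[str, Fn], target, preds
) -> Tuple[Dict[str, float], float]:
    metric_logs = {name: fn(target, preds) for name, fn in metrics.items()}
    loss_logs = {name: fn(target, preds) for name, fn in losses.items()}
    logs = {**metric_logs, **loss_logs}  # one dict for logging
    total = sum(loss_logs.values())  # metrics are reported but never added
    return logs, total


logs, total = compute_all(
    metrics={"accuracy": accuracy},
    losses={"mse": mse, "mae": mae},
    target=[1.0, 0.0, 1.0, 1.0],
    preds=[0.75, 0.25, 0.25, 1.0],
)
print(logs)  # {'accuracy': 0.75, 'mse': 0.171875, 'mae': 0.3125}
print(total)  # 0.484375
```

In this sketch, a metric and a loss sharing a name would collide in `logs`, with the loss entry winning.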
def loss_fn(params, lms, x, y):
    ...
    batch_updates = lms.batch_updates(target=y, preds=logits, parameters=params)
    loss = batch_updates.total_loss()
    lms = lms.merge(batch_updates)
    ...
    return loss, lms

def loss_fn(params, lms, x, y):
    ...
    loss, lms = lms.loss_and_update(target=y, preds=logits, parameters=params)
    ...
    return loss, lms

def loss_fn(params, lms, x, y):
    ...
    batch_updates = lms.batch_updates(target=y, preds=logits, parameters=params)
    # loss for this device's batch
    loss = batch_updates.total_loss()
    # gather over all devices and aggregate
    batch_updates = jax.lax.all_gather(batch_updates, "device").aggregate()
    lms = lms.merge(batch_updates)
    ...
    return loss, lms

Download files

Source Distribution

jax_metrics-0.1.0a0.tar.gz (53.5 kB)

Built Distribution

jax_metrics-0.1.0a0-py3-none-any.whl (77.3 kB)

File details

Details for the file jax_metrics-0.1.0a0.tar.gz.

File metadata

  • Download URL: jax_metrics-0.1.0a0.tar.gz
  • Size: 53.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.13 CPython/3.8.11 Linux/5.15.11-200.fc35.x86_64

File hashes

Hashes for jax_metrics-0.1.0a0.tar.gz
Algorithm Hash digest
SHA256 464c9c4154a9e0f7a2e626b785d57fb9366a5907f9bcaed15754f71f3b10b30e
MD5 bafa65d234d28cdeaf608183902ecbc9
BLAKE2b-256 f60fb3c10d3cf40727e844383826f536387bb1771abfc08a54d1d855e3fbea11


File details

Details for the file jax_metrics-0.1.0a0-py3-none-any.whl.

File metadata

  • Download URL: jax_metrics-0.1.0a0-py3-none-any.whl
  • Size: 77.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.13 CPython/3.8.11 Linux/5.15.11-200.fc35.x86_64

File hashes

Hashes for jax_metrics-0.1.0a0-py3-none-any.whl
Algorithm Hash digest
SHA256 38920dbcffef4d892878c420ab4801202798e9966c197dea4486975fa8f98a22
MD5 c674506c39bdc97e91500e8a6d84da92
BLAKE2b-256 f47e109af1a52bc3665455dc0b508e9e80024640b9372518b535d0231d5bb1dc
