JAX Metrics
A Metrics library for the JAX ecosystem
Main Features
- Standard framework-independent metrics that can be used in any JAX project.
- Pytree-based abstractions that can natively integrate with all JAX APIs.
- Distributed-friendly APIs that make it super easy to synchronize metrics across devices.
- Automatic accumulation over entire epochs.
JAX Metrics is implemented on top of Treeo.
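Because metrics are immutable pytrees, every operation returns a new instance instead of mutating state in place. A minimal pure-Python sketch of that functional pattern (illustrative only, not the library's actual implementation):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class RunningAccuracy:
    """Toy accuracy metric following the reset/update/compute pattern."""
    correct: int = 0
    total: int = 0

    def reset(self) -> "RunningAccuracy":
        # return a fresh instance instead of mutating self
        return RunningAccuracy()

    def update(self, target, preds) -> "RunningAccuracy":
        hits = sum(int(t == p) for t, p in zip(target, preds))
        return replace(self, correct=self.correct + hits,
                       total=self.total + len(target))

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0

metric = RunningAccuracy().reset()
metric = metric.update(target=[1, 0, 1, 1], preds=[1, 0, 0, 1])
print(metric.compute())  # 0.75
```

Since the state lives in plain fields rather than hidden mutable buffers, instances like this can be passed in and out of pure functions, which is what makes the real library compose with JAX transformations.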
What is included?
- A Keras-like `Metric` abstraction.
- A Keras-like `Loss` abstraction.
- `Metrics`, `Losses`, and `LossesAndMetrics` combinators.
- A `metrics` module containing popular metrics.
- A `losses` and a `regularizers` module containing popular losses.
Installation
Install using pip:
```bash
pip install jax_metrics
```
Getting Started
```python
import jax_metrics as jm

metric = jm.metrics.Accuracy()

# Initialize the metric
metric = metric.reset()

# Update the metric with a batch of predictions and labels
metric = metric.update(target=y, preds=logits)

# Get the current value of the metric
acc = metric.compute()  # 0.95

# alternatively, produce a logs dict
logs = metric.compute_logs()  # {'accuracy': 0.95}
```
Since metrics are pytrees, they can be passed through `jax.jit` and other JAX transformations:

```python
import jax
import jax_metrics as jm

metric = jm.metrics.Accuracy()

@jax.jit
def init_step(metric: jm.Metric) -> jm.Metric:
    return metric.reset()

def loss_fn(params, metric, x, y):
    ...
    metric = metric.update(target=y, preds=logits)
    ...
    return loss, metric

@jax.jit
def train_step(params, metric, x, y):
    grads, metric = jax.grad(loss_fn, has_aux=True)(
        params, metric, x, y
    )
    ...
    return params, metric
```
In distributed setups, `batch_updates` plus `merge` let you synchronize metric state across devices:

```python
def loss_fn(params, metric, x, y):
    ...
    # compute batch updates
    batch_updates = metric.batch_updates(target=y, preds=logits)
    # gather over all devices and aggregate
    batch_updates = jax.lax.all_gather(batch_updates, "device").aggregate()
    # update metric
    metric = metric.merge(batch_updates)
    ...
```

Depending on the metric's internal state, the gather-and-aggregate step can often be replaced by a single sum across devices:

```python
batch_updates = jax.lax.psum(batch_updates, "device")
```
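The reason this works is that count-based metric state is additive: summing per-device partial states and then computing gives the same result as computing over all the data at once. A device-free toy illustration:

```python
# Hypothetical per-device partial accuracy states: (correct, total) counts
device_states = [(3, 4), (2, 4), (4, 4)]

# An all-reduce such as jax.lax.psum sums each field across devices
correct = sum(c for c, _ in device_states)
total = sum(t for _, t in device_states)

print(correct / total)  # 0.75
```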
The `Metrics` combinator groups several metrics behind the same API:

```python
metrics = jm.Metrics([
    jm.metrics.Accuracy(),
    jm.metrics.F1(),  # not yet implemented 😅, coming soon?
])

# same API
metrics = metrics.reset()
# same API
metrics = metrics.update(target=y, preds=logits)
# compute now returns a dict
metrics.compute()  # {'accuracy': 0.95, 'f1': 0.87}
# same as compute in this case
metrics.compute_logs()  # {'accuracy': 0.95, 'f1': 0.87}
```
A dict can be used instead to give each metric a custom name:

```python
metrics = jm.Metrics({
    "acc": jm.metrics.Accuracy(),
    "f_one": jm.metrics.F1(),  # not yet implemented 😅, coming soon?
})

# same API
metrics = metrics.reset()
# same API
metrics = metrics.update(target=y, preds=logits)
# compute returns a dict keyed by the given names
metrics.compute()  # {'acc': 0.95, 'f_one': 0.87}
# same as compute in this case
metrics.compute_logs()  # {'acc': 0.95, 'f_one': 0.87}
```
The `Losses` combinator works the same way, grouping losses and regularizers:

```python
losses = jm.Losses([
    jm.losses.Crossentropy(),
    jm.regularizers.L2(1e-4),
])

# same API
losses = losses.reset()
# same API
losses = losses.update(target=y, preds=logits, parameters=params)
# compute now returns a dict
losses.compute()  # {'crossentropy': 0.23, 'l2': 0.005}
# same as compute in this case
losses.compute_logs()  # {'crossentropy': 0.23, 'l2': 0.005}
# you can also compute the total loss
total_loss = losses.total_loss()  # 0.235
```
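`total_loss` is simply the sum of the individual loss values from the dict above:

```python
loss_values = {"crossentropy": 0.23, "l2": 0.005}
total_loss = sum(loss_values.values())
print(total_loss)  # 0.23 + 0.005 ≈ 0.235
```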
As with metrics, a dict assigns custom names:

```python
losses = jm.Losses({
    "xent": jm.losses.Crossentropy(),
    "l_two": jm.regularizers.L2(1e-4),
})

# same API
losses = losses.reset()
# same API
losses = losses.update(target=y, preds=logits, parameters=params)
# compute returns a dict keyed by the given names
losses.compute()  # {'xent': 0.23, 'l_two': 0.005}
# same as compute in this case
losses.compute_logs()  # {'xent': 0.23, 'l_two': 0.005}
# you can also compute the total loss
total_loss = losses.total_loss()  # 0.235
Inside a loss function, compute the batch updates first, take the total loss from them, and merge them into the running state:

```python
def loss_fn(...):
    ...
    batch_updates = losses.batch_updates(target=y, preds=logits, parameters=params)
    loss = batch_updates.total_loss()
    losses = losses.merge(batch_updates)
    ...
    return loss, losses
```
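The batch-updates/merge split mirrors a simple accumulator: `batch_updates` computes a partial state for the current batch, and `merge` folds it into the running state. A toy sketch of those semantics with a running-mean loss (hypothetical names, not the library's code):

```python
# Running-mean state: (sum_of_losses, count); merge() adds the fields
def batch_updates(batch_losses):
    return (sum(batch_losses), len(batch_losses))

def merge(state, updates):
    return (state[0] + updates[0], state[1] + updates[1])

def compute(state):
    return state[0] / state[1]

state = (0.0, 0)
for batch in ([1.0, 3.0], [2.0, 2.0]):
    state = merge(state, batch_updates(batch))

print(compute(state))  # 2.0
```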
Alternatively, `loss_and_update` performs all three steps in a single call:

```python
def loss_fn(...):
    ...
    loss, losses = losses.loss_and_update(target=y, preds=logits, parameters=params)
    ...
    return loss, losses
```
The `LossesAndMetrics` combinator tracks both at once:

```python
lms = jm.LossesAndMetrics(
    metrics=[
        jm.metrics.Accuracy(),
        jm.metrics.F1(),  # not yet implemented 😅, coming soon?
    ],
    losses=[
        jm.losses.Crossentropy(),
        jm.regularizers.L2(1e-4),
    ],
)

# same API
lms = lms.reset()
# same API
lms = lms.update(target=y, preds=logits, parameters=params)
# compute now returns a dict
lms.compute()  # {'accuracy': 0.95, 'f1': 0.87, 'crossentropy': 0.23, 'l2': 0.005}
# same as compute in this case
lms.compute_logs()  # {'accuracy': 0.95, 'f1': 0.87, 'crossentropy': 0.23, 'l2': 0.005}
# you can also compute the total loss
total_loss = lms.total_loss()  # 0.235
```
The same batch-updates pattern applies:

```python
def loss_fn(...):
    ...
    batch_updates = lms.batch_updates(target=y, preds=logits, parameters=params)
    loss = batch_updates.total_loss()
    lms = lms.merge(batch_updates)
    ...
    return loss, lms
```
or, more compactly:

```python
def loss_fn(...):
    ...
    loss, lms = lms.loss_and_update(target=y, preds=logits, parameters=params)
    ...
    return loss, lms
```
For distributed training, aggregate the batch updates across devices before merging:

```python
def loss_fn(...):
    ...
    batch_updates = lms.batch_updates(target=y, preds=logits, parameters=params)
    loss = batch_updates.total_loss()
    batch_updates = jax.lax.all_gather(batch_updates, "device").aggregate()
    lms = lms.merge(batch_updates)
    ...
    return loss, lms
```