JaxNN: Foundation Models in JAX/Flax

JaxNN is an open-source library for foundation models in JAX and Flax. It provides a unified interface for creating models (e.g., ResNet, ViT), loading pretrained weights, and using them for inference, feature extraction, and embeddings.

Installation

pip install jaxnn

Usage

Image Classification

from urllib.request import urlopen
from PIL import Image
import jax

import jaxnn

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/cats-image/resolve/main/cats_image.jpeg'
))

model = jaxnn.create_model('resnet34.a1_in1k', pretrained=True)
model.eval()

# Get model-specific transforms (normalization, resize)
data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

output = model(jax.numpy.expand_dims(transforms(img), 0))

top5_probabilities, top5_class_indices = jax.lax.top_k(
    jax.nn.softmax(output, axis=-1) * 100, k=5
)
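The softmax/top-k post-processing above is plain JAX and can be tried in isolation; a minimal sketch using toy logits in place of the model output:

```python
import jax
import jax.numpy as jnp

# Toy logits standing in for the model output (batch of 1, 10 classes).
logits = jnp.array([[0.1, 2.3, -1.0, 0.5, 3.1, 0.0, 1.7, -0.4, 0.9, 2.8]])

# Convert to percentage probabilities, then take the five largest.
probabilities = jax.nn.softmax(logits, axis=-1) * 100
top5_probabilities, top5_class_indices = jax.lax.top_k(probabilities, k=5)

print(top5_class_indices[0])  # class indices sorted by descending probability
```

With a real model, the returned indices map to ImageNet-1k class labels.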

Feature Map Extraction

from urllib.request import urlopen
from PIL import Image
import jax

import jaxnn

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/cats-image/resolve/main/cats_image.jpeg'
))

model = jaxnn.create_model(
    'resnet34.a1_in1k',
    pretrained=True,
    features_only=True,
)
model.eval()

data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

output = model(jax.numpy.expand_dims(transforms(img), 0))

for o in output:
    print(o.shape)
# (1, 112, 112, 64)
# (1, 56, 56, 64)
# (1, 28, 28, 128)
# (1, 14, 14, 256)
# (1, 7, 7, 512)
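Each stage output is an NHWC array, so a global average over the spatial axes turns every feature map into a per-stage descriptor; a minimal sketch with dummy arrays shaped like the outputs above:

```python
import jax.numpy as jnp

# Dummy feature maps with the same NHWC shapes as the stage outputs above.
feature_maps = [
    jnp.ones((1, 112, 112, 64)),
    jnp.ones((1, 56, 56, 64)),
    jnp.ones((1, 28, 28, 128)),
    jnp.ones((1, 14, 14, 256)),
    jnp.ones((1, 7, 7, 512)),
]

# Average over height and width (axes 1 and 2) to get one vector per stage.
pooled = [fm.mean(axis=(1, 2)) for fm in feature_maps]

for p in pooled:
    print(p.shape)
# (1, 64)
# (1, 64)
# (1, 128)
# (1, 256)
# (1, 512)
```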

Image Embeddings

from urllib.request import urlopen
from PIL import Image
import jax

import jaxnn

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/cats-image/resolve/main/cats_image.jpeg'
))

model = jaxnn.create_model(
    'resnet34.a1_in1k',
    pretrained=True,
    num_classes=0,  # remove classifier
)
model.eval()

data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

output = model(jax.numpy.expand_dims(transforms(img), 0))

# Or use forward methods directly:
output = model.forward_features(jax.numpy.expand_dims(transforms(img), 0))  # (1, 7, 7, 512)
output = model.forward_head(output, pre_logits=True)                        # (1, num_features)
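A typical use of such embeddings is similarity search. The sketch below is plain JAX, with random vectors standing in for model embeddings (the 512-dimensional size is an assumption matching resnet34's feature width above):

```python
import jax
import jax.numpy as jnp

def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity along the last axis of two batches of embeddings."""
    a = a / (jnp.linalg.norm(a, axis=-1, keepdims=True) + eps)
    b = b / (jnp.linalg.norm(b, axis=-1, keepdims=True) + eps)
    return jnp.sum(a * b, axis=-1)

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
emb_a = jax.random.normal(key_a, (1, 512))  # stand-in for a model embedding
emb_b = jax.random.normal(key_b, (1, 512))

print(cosine_similarity(emb_a, emb_a))  # identical inputs give ~1.0
print(cosine_similarity(emb_a, emb_b))
```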

Pretrained models

The img/sec throughput column was measured with the PyTorch implementations of these models. For each ported model, the output logits were compared against those produced by the original PyTorch weights to verify the conversion.

model img_size top1 (%) top5 (%) param_count (M) GMACs MActs img/sec
resnet152d.ra2_in1k 320 83.67 96.74 60.2 24.1 47.7 706
resnet152.a1h_in1k 288 83.46 96.54 60.2 19.1 37.3 904
resnet152d.ra2_in1k 256 83.14 96.38 60.2 15.4 30.5 1096
resnet101d.ra2_in1k 320 83.02 96.45 44.6 16.5 34.8 992
resnet152.a1h_in1k 224 82.8 96.13 60.2 11.6 22.6 1486
resnet101.a1h_in1k 288 82.8 96.32 44.6 13.0 26.8 1291
resnet152.a1_in1k 288 82.74 95.71 60.2 19.1 37.3 905
resnet152.a2_in1k 288 82.62 95.75 60.2 19.1 37.3 904
resnet101.a1_in1k 288 82.31 95.63 44.6 13.0 26.8 1291
resnet152.tv2_in1k 224 82.29 96.0 60.2 11.6 22.6 1484
resnet101d.ra2_in1k 256 82.26 96.07 44.6 10.6 22.2 1542
resnet101.a2_in1k 288 82.24 95.73 44.6 13.0 26.8 1290
resnet152.a1_in1k 224 81.97 95.24 60.2 11.6 22.6 1486
resnet101.a1h_in1k 224 81.93 95.75 44.6 7.8 16.2 2122
resnet101.tv2_in1k 224 81.9 95.77 44.6 7.8 16.2 2118
resnet152.a2_in1k 224 81.77 95.22 60.2 11.6 22.6 1485
resnet101.a1_in1k 224 81.5 95.16 44.6 7.8 16.2 2125
resnet50d.a1_in1k 288 81.44 95.22 25.6 7.2 19.7 1908
resnet50d.ra2_in1k 288 81.37 95.74 25.6 7.2 19.7 1910
resnet101.a2_in1k 224 81.32 95.19 44.6 7.8 16.2 2125
resnet50.a1_in1k 288 81.22 95.11 25.6 6.8 18.4 2089
resnet50_gn.a1h_in1k 288 81.22 95.63 25.6 6.8 18.4 676
resnet50d.a2_in1k 288 81.18 95.09 25.6 7.2 19.7 1908
resnet50.fb_swsl_ig1b_ft_in1k 224 81.18 95.98 25.6 4.1 11.1 3455
resnet152s.gluon_in1k 224 81.02 95.41 60.3 12.9 25.0 1347
resnet50.d_in1k 288 80.97 95.44 25.6 6.8 18.4 2085
resnet50.c1_in1k 288 80.91 95.55 25.6 6.8 18.4 2084
resnet50.c2_in1k 288 80.86 95.52 25.6 6.8 18.4 2085
resnet50.tv2_in1k 224 80.85 95.43 25.6 4.1 11.1 3450
resnet50.a2_in1k 288 80.78 94.99 25.6 6.8 18.4 2088
resnet50.b1k_in1k 288 80.71 95.43 25.6 6.8 18.4 2087
resnet50d.a1_in1k 224 80.68 94.71 25.6 4.4 11.9 3162
resnet152.a3_in1k 224 80.56 95.0 60.2 11.6 22.6 1483
resnet50d.ra2_in1k 224 80.53 95.16 25.6 4.4 11.9 3164
resnet152d.gluon_in1k 224 80.47 95.2 60.2 11.8 23.4 1428
resnet50.b2k_in1k 288 80.45 95.32 25.6 6.8 18.4 2086
resnet101d.gluon_in1k 224 80.42 95.01 44.6 8.1 17.0 2007
resnet50.a1_in1k 224 80.38 94.6 25.6 4.1 11.1 3461
resnet101s.gluon_in1k 224 80.28 95.16 44.7 9.2 18.6 1851
resnet50d.a2_in1k 224 80.22 94.63 25.6 4.4 11.9 3162
resnet152.tv2_in1k 176 80.2 94.64 60.2 7.2 14.0 2346
resnet50_gn.a1h_in1k 224 80.06 94.95 25.6 4.1 11.1 1109
resnet50.ram_in1k 288 79.97 95.05 25.6 6.8 18.4 2086
resnet152c.gluon_in1k 224 79.92 94.84 60.2 11.8 23.4 1455
resnet50.d_in1k 224 79.91 94.67 25.6 4.1 11.1 3456
resnet101.tv2_in1k 176 79.9 94.6 44.6 4.9 10.1 3341
resnet50.c2_in1k 224 79.88 94.87 25.6 4.1 11.1 3455
resnet50.a2_in1k 224 79.85 94.56 25.6 4.1 11.1 3460
resnet50.ra_in1k 288 79.83 94.97 25.6 6.8 18.4 2087
resnet101.a3_in1k 224 79.82 94.62 44.6 7.8 16.2 2114
resnet50.c1_in1k 224 79.74 94.95 25.6 4.1 11.1 3455
resnet152.gluon_in1k 224 79.68 94.74 60.2 11.6 22.6 1486
resnet50.bt_in1k 288 79.63 94.91 25.6 6.8 18.4 2086
resnet101c.gluon_in1k 224 79.53 94.58 44.6 8.1 17.0 2062
resnet50.b1k_in1k 224 79.52 94.61 25.6 4.1 11.1 3459
resnet50.tv2_in1k 176 79.42 94.64 25.6 2.6 6.9 5397
resnet50.b2k_in1k 224 79.38 94.57 25.6 4.1 11.1 3459
resnet101.gluon_in1k 224 79.31 94.53 44.6 7.8 16.2 2125
resnet50.fb_ssl_yfcc100m_ft_in1k 224 79.22 94.84 25.6 4.1 11.1 3451
resnet50d.gluon_in1k 224 79.07 94.48 25.6 4.4 11.9 3162
resnet50.ram_in1k 224 79.03 94.38 25.6 4.1 11.1 3453
resnet50.am_in1k 224 79.01 94.39 25.6 4.1 11.1 3461
resnet152.a3_in1k 160 78.89 94.11 60.2 5.9 11.5 2745
resnet50.ra_in1k 224 78.81 94.32 25.6 4.1 11.1 3454
resnet50s.gluon_in1k 224 78.72 94.23 25.7 5.5 13.5 2796
resnet50d.a3_in1k 224 78.71 94.24 25.6 4.4 11.9 3154
resnet50.bt_in1k 224 78.46 94.27 25.6 4.1 11.1 3454
resnet34d.ra2_in1k 288 78.43 94.35 21.8 6.5 7.5 3291
resnet26t.ra2_in1k 320 78.33 94.13 16.0 5.2 16.4 2391
resnet152.tv_in1k 224 78.32 94.04 60.2 11.6 22.6 1487
resnet50.a3_in1k 224 78.06 93.78 25.6 4.1 11.1 3450
resnet50c.gluon_in1k 224 78.0 93.99 25.6 4.4 11.9 3286
resnet34.a1_in1k 288 77.92 93.77 21.8 6.1 6.2 3609
resnet101.a3_in1k 160 77.88 93.71 44.6 4.0 8.3 3926
resnet26t.ra2_in1k 256 77.87 93.84 16.0 3.4 10.5 3772
resnet50.gluon_in1k 224 77.58 93.72 25.6 4.1 11.1 3455
resnet26d.bt_in1k 288 77.41 93.63 16.0 4.3 13.5 2907
resnet101.tv_in1k 224 77.38 93.54 44.6 7.8 16.2 2125
resnet50d.a3_in1k 160 77.22 93.27 25.6 2.2 6.1 5982
resnet34.a2_in1k 288 77.15 93.27 21.8 6.1 6.2 3615
resnet34d.ra2_in1k 224 77.1 93.37 21.8 3.9 4.5 5436
resnet26d.bt_in1k 224 76.7 93.17 16.0 2.6 8.2 4859
resnet34.bt_in1k 288 76.5 93.35 21.8 6.1 6.2 3617
resnet34.a1_in1k 224 76.42 92.87 21.8 3.7 3.7 5984
resnet26.bt_in1k 288 76.35 93.18 16.0 3.9 12.2 3331
resnet50.tv_in1k 224 76.13 92.86 25.6 4.1 11.1 3457
resnet50.a3_in1k 160 75.96 92.5 25.6 2.1 5.7 6490
resnet34.a2_in1k 224 75.52 92.44 21.8 3.7 3.7 5991
resnet26.bt_in1k 224 75.3 92.58 16.0 2.4 7.4 5583
resnet34.bt_in1k 224 75.16 92.18 21.8 3.7 3.7 5994
resnet34.gluon_in1k 224 74.57 91.98 21.8 3.7 3.7 5984
resnet18d.ra2_in1k 288 73.81 91.83 11.7 3.4 5.4 5196
resnet34.tv_in1k 224 73.32 91.42 21.8 3.7 3.7 5979
resnet18.fb_swsl_ig1b_ft_in1k 224 73.28 91.73 11.7 1.8 2.5 10213
resnet18.a1_in1k 288 73.16 91.03 11.7 3.0 4.1 6050
resnet34.a3_in1k 224 72.98 91.11 21.8 3.7 3.7 5967
resnet18.fb_ssl_yfcc100m_ft_in1k 224 72.6 91.42 11.7 1.8 2.5 10213
resnet18.a2_in1k 288 72.37 90.59 11.7 3.0 4.1 6051
resnet14t.c3_in1k 224 72.26 90.31 10.1 1.7 5.8 7026
resnet18d.ra2_in1k 224 72.26 90.68 11.7 2.1 3.3 8707
resnet18.a1_in1k 224 71.49 90.07 11.7 1.8 2.5 10187
resnet14t.c3_in1k 176 71.31 89.69 10.1 1.1 3.6 10970
resnet18.gluon_in1k 224 70.84 89.76 11.7 1.8 2.5 10210
resnet18.a2_in1k 224 70.64 89.47 11.7 1.8 2.5 10194
resnet34.a3_in1k 160 70.56 89.52 21.8 1.9 1.9 10737
resnet18.tv_in1k 224 69.76 89.07 11.7 1.8 2.5 10205
resnet18.a3_in1k 224 68.25 88.17 11.7 1.8 2.5 10167
resnet18.a3_in1k 160 65.66 86.26 11.7 0.9 1.3 18229
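The table supports simple speed/accuracy trade-off queries; a small sketch (plain Python, using a handful of rows copied from above) that picks the most accurate model meeting a throughput budget:

```python
# A few (model, top1, img/sec) rows copied from the table above.
rows = [
    ("resnet152d.ra2_in1k", 83.67, 706),
    ("resnet50d.ra2_in1k", 81.37, 1910),
    ("resnet50.tv2_in1k", 80.85, 3450),
    ("resnet34.a1_in1k", 77.92, 3609),
    ("resnet18.tv_in1k", 69.76, 10205),
]

def best_under_budget(rows, min_img_per_sec):
    """Most accurate model that sustains at least the given throughput."""
    candidates = [r for r in rows if r[2] >= min_img_per_sec]
    return max(candidates, key=lambda r: r[1]) if candidates else None

print(best_under_budget(rows, 3000))  # ('resnet50.tv2_in1k', 80.85, 3450)
```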

Roadmap

Planned components:

  • ViT, MobileNet, and more architectures
  • Training/eval loop with optax
  • Documentation

Download files

Download the file for your platform.

Source Distribution

jaxnn-0.1.22.tar.gz (55.8 kB)

Built Distribution

jaxnn-0.1.22-py3-none-any.whl (56.5 kB)

File details

Details for the file jaxnn-0.1.22.tar.gz.

File metadata

  • Download URL: jaxnn-0.1.22.tar.gz
  • Upload date:
  • Size: 55.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxnn-0.1.22.tar.gz
Algorithm Hash digest
SHA256 b2e57f942269716db6df23c18c10be7773d318a3e2e11f2a6e25451adcdc8973
MD5 60eee4100d815910f7b38f01cab40997
BLAKE2b-256 2d66ffa52da4d779a573a41bd4f458e04207e0dff8a94624446c6e25013ba98f


Provenance

The following attestation bundles were made for jaxnn-0.1.22.tar.gz:

Publisher: publish.yml on Xrenya/jaxnn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxnn-0.1.22-py3-none-any.whl.

File metadata

  • Download URL: jaxnn-0.1.22-py3-none-any.whl
  • Upload date:
  • Size: 56.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxnn-0.1.22-py3-none-any.whl
Algorithm Hash digest
SHA256 86440147450dc39fe5c282450c2ee1869750aec4728ac02471fd95a2f93c1f31
MD5 6684a82089471982e6ef8a5b22a39342
BLAKE2b-256 5e62a33eda4508d6019a10583310c642c479dbaae038a1f95014cef86a002f95


Provenance

The following attestation bundles were made for jaxnn-0.1.22-py3-none-any.whl:

Publisher: publish.yml on Xrenya/jaxnn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
