Foundation models in JAX/Flax

JaxNN: Foundation Models in JAX/Flax

JaxNN is an open-source library for foundation models in JAX and Flax. It provides a unified framework for loading, creating, and using pretrained models (e.g., ResNet, ViT).
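Model names used in the examples, such as 'resnet34.a1_in1k', appear to follow an "architecture.pretrained-tag" pattern. The sketch below shows how a name-based factory could dispatch on such names. It assumes that pattern. `register_model`, `split_model_name`, and `create_model_sketch` are hypothetical and not part of JaxNN's documented API.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry (not JaxNN's API): architecture name -> constructor.
_MODEL_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register_model(name: str) -> Callable:
    """Decorator that records a constructor under `name` (illustrative only)."""
    def decorator(fn: Callable[..., dict]) -> Callable[..., dict]:
        _MODEL_REGISTRY[name] = fn
        return fn
    return decorator


def split_model_name(model_name: str) -> Tuple[str, str]:
    """Split 'arch.tag' into ('arch', 'tag'); the tag is '' if there is no dot."""
    arch, _, tag = model_name.partition(".")
    return arch, tag


@register_model("resnet34")
def _resnet34(num_classes: int = 1000, **kwargs) -> dict:
    # Stand-in for building a real Flax module: returns a plain config dict.
    return {"arch": "resnet34", "num_classes": num_classes, **kwargs}


def create_model_sketch(model_name: str, pretrained: bool = False, **kwargs) -> dict:
    """Look up the architecture, build it, and record which weights to load."""
    arch, tag = split_model_name(model_name)
    if arch not in _MODEL_REGISTRY:
        raise ValueError(f"Unknown architecture: {arch!r}")
    config = _MODEL_REGISTRY[arch](**kwargs)
    # A real implementation would fetch the weights named by `tag` here.
    config["pretrained_tag"] = tag if pretrained else None
    return config


print(create_model_sketch("resnet34.a1_in1k", pretrained=True, num_classes=0))
# {'arch': 'resnet34', 'num_classes': 0, 'pretrained_tag': 'a1_in1k'}
```

Keeping the architecture and the weight tag separate lets one constructor serve several pretrained checkpoints.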

Installation

pip install jaxnn

Usage

Image Classification

from urllib.request import urlopen
from PIL import Image
import jax

import jaxnn

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/cats-image/resolve/main/cats_image.jpeg'
))

model = jaxnn.create_model('resnet34.a1_in1k', pretrained=True)
model.eval()

# Get model-specific transforms (normalization, resize)
data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

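# Add a batch dimension of size 1 before calling the model
output = model(jax.numpy.expand_dims(transforms(img), 0))

# Convert logits to percentages and keep the five most likely classes
top5_probabilities, top5_class_indices = jax.lax.top_k(
    jax.nn.softmax(output, axis=-1) * 100, k=5
)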
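You can try the post-processing step without downloading any weights. The sketch below runs the same `jax.nn.softmax` and `jax.lax.top_k` calls on a small made-up logits array that stands in for `output`. The numbers are illustrative only. Mapping indices to ImageNet class names would need a label list, which is not shown here.

```python
import jax
import jax.numpy as jnp


def top_k_percent(logits: jax.Array, k: int = 5):
    """Softmax over the class axis, scaled to percent, then top-k per row."""
    probs = jax.nn.softmax(logits, axis=-1) * 100.0
    # jax.lax.top_k operates on the last axis and returns (values, indices),
    # with values sorted in descending order.
    return jax.lax.top_k(probs, k)


# Made-up batch of logits (batch=1, num_classes=7) standing in for `output`.
logits = jnp.array([[0.1, 2.0, -1.0, 3.0, 0.5, 1.5, -0.2]])
values, indices = top_k_percent(logits, k=5)
print(indices.tolist())  # [[3, 1, 5, 4, 0]]
```

Because softmax is monotonic, the top-k indices match those of the raw logits. The softmax only turns the scores into percentages that sum to 100.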

Feature Map Extraction

from urllib.request import urlopen
from PIL import Image
import jax

import jaxnn

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/cats-image/resolve/main/cats_image.jpeg'
))

model = jaxnn.create_model(
    'resnet34.a1_in1k',
    pretrained=True,
    features_only=True,
)
model.eval()

data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

output = model(jax.numpy.expand_dims(transforms(img), 0))

for o in output:
    print(o.shape)
# (1, 112, 112, 64)
# (1, 56, 56, 64)
# (1, 28, 28, 128)
# (1, 14, 14, 256)
# (1, 7, 7, 512)
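The printed shapes are channels-last (N, H, W, C), the usual JAX/Flax layout, and each level halves the spatial size. The sketch below recovers each level's reduction stride and channel count from those shapes. It assumes a 224x224 input, inferred from the 112x112 first level. It also shows a transpose for code that expects PyTorch-style (N, C, H, W) tensors. Both helpers are illustrative, not JaxNN functions.

```python
import jax.numpy as jnp

# Shapes printed above for resnet34.a1_in1k, in (N, H, W, C) order.
feature_shapes = [
    (1, 112, 112, 64),
    (1, 56, 56, 64),
    (1, 28, 28, 128),
    (1, 14, 14, 256),
    (1, 7, 7, 512),
]
input_size = 224  # assumption: inferred from the 112x112 first level


def reduction_and_channels(shapes, input_size):
    """Return (stride, channels) per level, assuming square NHWC maps."""
    return [(input_size // h, c) for (_, h, _, c) in shapes]


def nhwc_to_nchw(x):
    """Move channels from the last axis to axis 1 (PyTorch-style layout)."""
    return jnp.transpose(x, (0, 3, 1, 2))


print(reduction_and_channels(feature_shapes, input_size))
# [(2, 64), (4, 64), (8, 128), (16, 256), (32, 512)]

dummy = jnp.zeros(feature_shapes[-1])
print(nhwc_to_nchw(dummy).shape)  # (1, 512, 7, 7)
```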

Image Embeddings

from urllib.request import urlopen
from PIL import Image
import jax

import jaxnn

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/cats-image/resolve/main/cats_image.jpeg'
))

model = jaxnn.create_model(
    'resnet34.a1_in1k',
    pretrained=True,
    num_classes=0,  # remove classifier
)
model.eval()

data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

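# With num_classes=0 the model returns pooled features instead of logits
output = model(jax.numpy.expand_dims(transforms(img), 0))  # (1, num_features)

# Or call the forward methods directly (unpooled feature map, then pooled head):
output = model.forward_features(jax.numpy.expand_dims(transforms(img), 0))  # (1, 7, 7, 512)
output = model.forward_head(output, pre_logits=True)                        # (1, num_features)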
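Image embeddings are usually compared with cosine similarity. The sketch below assumes the head performs global average pooling over the 7x7 grid. That is standard for ResNet, but the description above does not confirm it. Random arrays stand in for real `forward_features` outputs. `global_avg_pool` and `cosine_similarity` are illustrative helpers, not JaxNN functions.

```python
import jax
import jax.numpy as jnp


def global_avg_pool(feature_map: jax.Array) -> jax.Array:
    """(N, H, W, C) -> (N, C) by averaging over the spatial axes."""
    return jnp.mean(feature_map, axis=(1, 2))


def cosine_similarity(a: jax.Array, b: jax.Array, eps: float = 1e-8) -> jax.Array:
    """Row-wise cosine similarity between two (N, C) embedding batches."""
    a = a / (jnp.linalg.norm(a, axis=-1, keepdims=True) + eps)
    b = b / (jnp.linalg.norm(b, axis=-1, keepdims=True) + eps)
    return jnp.sum(a * b, axis=-1)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
# Random stand-ins for two forward_features outputs of shape (1, 7, 7, 512).
feats_a = jax.random.normal(key_a, (1, 7, 7, 512))
feats_b = jax.random.normal(key_b, (1, 7, 7, 512))

emb_a = global_avg_pool(feats_a)  # (1, 512)
emb_b = global_avg_pool(feats_b)  # (1, 512)
print(emb_a.shape)
print(cosine_similarity(emb_a, emb_b))  # near 0 for unrelated random features
```

With real embeddings, images with similar content should score closer to 1.0 than unrelated images.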

Roadmap

  • ViT, MobileNet, and more
  • Training/eval loop with optax
  • Documentation


Download files

Download the file for your platform.

Source Distribution

jaxnn-0.1.1.tar.gz (48.3 kB)

Uploaded Source

Built Distribution

jaxnn-0.1.1-py3-none-any.whl (51.2 kB)

Uploaded Python 3

File details

Details for the file jaxnn-0.1.1.tar.gz.

File metadata

  • Download URL: jaxnn-0.1.1.tar.gz
  • Upload date:
  • Size: 48.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxnn-0.1.1.tar.gz
Algorithm Hash digest
SHA256 79f22bac8b46a86f3dc080942b097f73f444ff1bc7ce84d6613c8c8820aeb679
MD5 e88ce012f1c9dd4335af34b063af8a2e
BLAKE2b-256 206ef9184c3711371fea9bf8d65ef77f1d7f69259ab140dfe62c213abf35b890
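The sketch below checks a downloaded archive against the SHA256 digest listed above, using only the standard library. It assumes the sdist was saved as `jaxnn-0.1.1.tar.gz` in the working directory.

```python
import hashlib
from pathlib import Path


def sha256_of_file(path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in 1 MiB chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected_hex: str) -> bool:
    """Compare against the published digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.lower()


EXPECTED = "79f22bac8b46a86f3dc080942b097f73f444ff1bc7ce84d6613c8c8820aeb679"
archive = Path("jaxnn-0.1.1.tar.gz")  # assumption: downloaded to this path
if archive.exists():
    print("OK" if verify_sha256(archive, EXPECTED) else "HASH MISMATCH")
```

pip can enforce the same check at install time in hash-checking mode. Add a requirements line such as `jaxnn==0.1.1 --hash=sha256:79f22bac8b46a86f3dc080942b097f73f444ff1bc7ce84d6613c8c8820aeb679 --hash=sha256:d5b2c112a3062e3adf57d96f24eca25dad20705c95938e657e41dfcaa52ad105`, then run `pip install --require-hashes -r requirements.txt`. In that mode, every dependency must also be pinned with a hash.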

Provenance

The following attestation bundles were made for jaxnn-0.1.1.tar.gz:

Publisher: publish.yml on Xrenya/jaxnn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxnn-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: jaxnn-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 51.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxnn-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 d5b2c112a3062e3adf57d96f24eca25dad20705c95938e657e41dfcaa52ad105
MD5 f3f25ee8bb827324f10165bb5591c901
BLAKE2b-256 865a3f7de15c7ffe6920eca66bbf6cde086582ad72a07d1a0427c0355b818b80

Provenance

The following attestation bundles were made for jaxnn-0.1.1-py3-none-any.whl:

Publisher: publish.yml on Xrenya/jaxnn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
