
Foundation models in JAX/Flax


JaxNN: Foundation Models in JAX/Flax

JaxNN is an open-source library for foundation models in JAX and Flax. It provides a unified framework for loading, creating, and using pretrained models (e.g., ResNet, ViT).

Note: jaxnn is still in development. A pip-installable release is not yet available; it will be published once more models have been ported to Flax/JAX.

Installation

pip install jaxnn  # coming soon

Usage

Image Classification

from urllib.request import urlopen
from PIL import Image
import jaxnn
import jax

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = jaxnn.create_model('resnet34.a1_in1k', pretrained=True)
model.eval()

# Get model-specific transforms (normalization, resize)
data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

output = model(jax.numpy.expand_dims(transforms(img), 0))  # logits, shape (1, 1000) for ImageNet-1k

top5_probabilities, top5_class_indices = jax.lax.top_k(
    jax.nn.softmax(output, axis=-1) * 100, k=5
)
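The classification snippet ends with the top-5 probabilities and indices. Below is a minimal sketch of that post-processing step on made-up logits, so it needs only JAX. `jax.nn.softmax` turns logits into probabilities, multiplying by 100 gives percentages, and `jax.lax.top_k` returns the k largest values and their indices along the last axis. The logits and the `labels` list are placeholders for illustration. For a real ImageNet-1k model the class names would have to come from a separate label file.

```python
import jax
import jax.numpy as jnp

# Made-up logits for a batch of one image over 8 placeholder classes.
logits = jnp.array([[1.0, 3.0, 0.5, 2.0, -1.0, 4.0, 0.0, 2.5]])
labels = ["c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7"]  # hypothetical class names

# Softmax over the class axis, expressed as percentages.
probs = jax.nn.softmax(logits, axis=-1) * 100

# top_k operates on the last axis and returns (values, indices), largest first.
top5_probabilities, top5_class_indices = jax.lax.top_k(probs, k=5)

for p, i in zip(top5_probabilities[0].tolist(), top5_class_indices[0].tolist()):
    print(f"{labels[i]}: {p:.2f}%")
```

Because `top_k` sorts in descending order, the first printed line is always the most likely class.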

Feature Map Extraction

from urllib.request import urlopen
from PIL import Image
import jaxnn
import jax

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = jaxnn.create_model(
    'resnet34.a1_in1k',
    pretrained=True,
    features_only=True,
)
model.eval()

data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

output = model(jax.numpy.expand_dims(transforms(img), 0))

for o in output:
    print(o.shape)
# (1, 112, 112, 64)
# (1, 56, 56, 64)
# (1, 28, 28, 128)
# (1, 14, 14, 256)
# (1, 7, 7, 512)
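The printed shapes are in NHWC layout and follow a fixed pattern: each stage reduces the spatial size by a factor of 2, 4, 8, 16 and 32 relative to the input (the first map is 112×112, consistent with a 224×224 input), with channel widths 64, 64, 128, 256 and 512. The hypothetical helper below (not part of jaxnn) predicts the shapes for other input sizes. It assumes ceiling division at every stage, which is what standard ResNet strided convolutions and pooling with padding produce.

```python
def feature_map_shapes(height, width, batch=1,
                       reductions=(2, 4, 8, 16, 32),
                       channels=(64, 64, 128, 256, 512)):
    """Predict NHWC feature-map shapes for a ResNet-34-style backbone.

    Assumes each stage's spatial size is ceil(input_size / reduction),
    with reductions and channel widths taken from the output above.
    """
    shapes = []
    for r, c in zip(reductions, channels):
        # -(-a // b) is integer ceiling division.
        shapes.append((batch, -(-height // r), -(-width // r), c))
    return shapes


for shape in feature_map_shapes(288, 288):
    print(shape)
```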

Image Embeddings

from urllib.request import urlopen
from PIL import Image
import jaxnn
import jax

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = jaxnn.create_model(
    'resnet34.a1_in1k',
    pretrained=True,
    num_classes=0,  # remove classifier
)
model.eval()

data_config = jaxnn.data.resolve_model_data_config(model)
transforms = jaxnn.data.create_transform(**data_config, is_training=False)

output = model(jax.numpy.expand_dims(transforms(img), 0))  # pooled embedding, (1, num_features)

# Or call the forward methods directly (unpooled features, then the pooled head):
output = model.forward_features(jax.numpy.expand_dims(transforms(img), 0))  # (1, 7, 7, 512)
output = model.forward_head(output, pre_logits=True)                         # (1, num_features)
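For a ResNet, the `pre_logits` output is usually produced by global average pooling of the `forward_features` output over its spatial axes. The sketch below assumes jaxnn follows that convention (not confirmed here) and uses a random array in place of real features so it runs with JAX alone. It also shows a common use of embeddings: comparing two images by cosine similarity. `cosine_similarity` is a hypothetical helper, not a jaxnn function.

```python
import jax
import jax.numpy as jnp

# Stand-in for model.forward_features(...): NHWC features for two images.
key = jax.random.PRNGKey(0)
features = jax.random.normal(key, (2, 7, 7, 512))

# Global average pooling over height and width (axes 1 and 2 in NHWC).
embeddings = jnp.mean(features, axis=(1, 2))  # (2, 512)


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity between two 1-D embedding vectors."""
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + eps)


print(embeddings.shape)  # (2, 512)
print(float(cosine_similarity(embeddings[0], embeddings[1])))
```

Values near 1 indicate similar images. With random stand-in features the printed similarity is not meaningful.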

Roadmap

  • Model registry + factory (create_model)
  • Pretrained ResNet family
  • Preprocessing + normalization
  • Weight loading from Hugging Face Hub
  • CLI tool (jaxnn list, jaxnn info)
  • PyPI package
  • ViT, MobileNet, and more
  • Training/eval loop with optax
  • Documentation
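The model registry and `create_model` factory can be pictured as a name-to-builder lookup. The sketch below is hypothetical and is not jaxnn's implementation: `register_model`, `list_models`, the `pretrained_tag` argument and the dictionary returned by the builder are all illustrative assumptions. It also assumes names such as `resnet34.a1_in1k` split into an architecture (`resnet34`) and a pretrained-weight tag (`a1_in1k`).

```python
from typing import Callable, Dict

# Hypothetical registry mapping architecture names to builder functions.
_MODEL_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register_model(fn):
    """Decorator that registers a builder under its function name."""
    _MODEL_REGISTRY[fn.__name__] = fn
    return fn


def list_models():
    return sorted(_MODEL_REGISTRY)


def create_model(model_name: str, **kwargs):
    """Split 'arch.tag', look up the builder, and forward keyword arguments."""
    arch, _, tag = model_name.partition(".")
    if arch not in _MODEL_REGISTRY:
        raise ValueError(f"Unknown model: {arch!r}")
    return _MODEL_REGISTRY[arch](pretrained_tag=tag or None, **kwargs)


@register_model
def resnet34(pretrained_tag=None, pretrained=False, num_classes=1000, **kwargs):
    # Placeholder: a real builder would construct a Flax module and, when
    # pretrained=True, load the weights identified by pretrained_tag.
    return {"arch": "resnet34", "tag": pretrained_tag,
            "pretrained": pretrained, "num_classes": num_classes, **kwargs}


model = create_model("resnet34.a1_in1k", pretrained=True, num_classes=0)
print(model)
```

Keeping the lookup in one dictionary is what makes listing commands such as `jaxnn list` cheap to implement.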


Download files

Source Distribution

jaxnn-0.1.0.tar.gz (47.1 kB)

Built Distribution

jaxnn-0.1.0-py3-none-any.whl (49.2 kB)

File details

Details for the file jaxnn-0.1.0.tar.gz.

File metadata

  • Download URL: jaxnn-0.1.0.tar.gz
  • Upload date:
  • Size: 47.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxnn-0.1.0.tar.gz
  • SHA256: df7b0bd14f66d6f1a38d3a9cb3a0c1dc02f3a2d42c95a19246d303a2a6e9b582
  • MD5: f8b486ba66c6965c85bc890a71c4e8ff
  • BLAKE2b-256: 6610a9ca9b043e88df831c2a59bbf3d20629d6c82a036e10fdc8738b914dc5db
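A downloaded archive can be checked against the published SHA256 digest. A minimal sketch using only the standard library `hashlib` module, assuming the file has been saved locally as `jaxnn-0.1.0.tar.gz`; `sha256_of_file` and `verify` are illustrative helpers. pip can also enforce digests itself through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).

```python
import hashlib

# Published SHA256 digest for jaxnn-0.1.0.tar.gz.
EXPECTED_SHA256 = "df7b0bd14f66d6f1a38d3a9cb3a0c1dc02f3a2d42c95a19246d303a2a6e9b582"


def sha256_of_file(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 in 1 MiB chunks and return the hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest matches the expected one."""
    return sha256_of_file(path) == expected.lower()
```

Usage: `verify("jaxnn-0.1.0.tar.gz")` returns `True` only for an unmodified copy of the source distribution.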

Provenance

The following attestation bundles were made for jaxnn-0.1.0.tar.gz:

Publisher: publish.yml on Xrenya/jaxnn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxnn-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jaxnn-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 49.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jaxnn-0.1.0-py3-none-any.whl
  • SHA256: bcb14f16aa9fb112fa9f3ad6381a2a384e01e201bbf36b0e7eb077bf697e2893
  • MD5: 80046137effca2a4ee2ae4d75a41c553
  • BLAKE2b-256: 1e64d9091c5528da79ee3c70b530006d80a1a342793a676bbfd74769dca57451

Provenance

The following attestation bundles were made for jaxnn-0.1.0-py3-none-any.whl:

Publisher: publish.yml on Xrenya/jaxnn

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
