
One-stop JAX foundation model repository

Project description

foundax


A unified JAX model zoo for neural operators, PINNs, and foundation models.

pip install foundax
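To confirm which release pip installed (the files listed on this page are version 0.1.3), you can read the distribution metadata with the standard library. This is a generic sketch that does not use any foundax API. The helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints the installed foundax version, or None if foundax is not installed.
    print(installed_version("foundax"))
```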

Usage

import foundax as fx

# Neural operator architectures (returns initialized FlaxModel)
model = fx.mlp(in_features=2, output_dim=1, hidden_dims=64, num_layers=3)
model = fx.fno2d(in_features=1, hidden_channels=32, n_modes=16)
model = fx.unet2d(in_channels=1, out_channels=1)
model = fx.transformer(num_tokens=1000, d_model=128, num_heads=8)

# Poseidon — (B, 128, 128, C) → (B, 128, 128, C)
model = fx.poseidon.T()   # T/B/L variants
model = fx.poseidon.B()
model = fx.poseidon.L()

# MORPH — (B, t, F, C, D, H, W) → (B, F, C, D, H, W)
model = fx.morph.Ti()    # Ti/S/M/L variants
model = fx.morph.S()
model = fx.morph.M()
model = fx.morph.L()

# MPP — (T, B, C, H, W) → (B, C, H, W)
model = fx.mpp.Ti(n_states=12)  # Ti/S/B/L variants
model = fx.mpp.S(n_states=12)
model = fx.mpp.B(n_states=12)
model = fx.mpp.L(n_states=12)

# Walrus — (B, T, H, W, C) → (B, T, H, W, C_out)
model = fx.walrus.base()

# BCAT — (B, T_in+T_out, 128, 128, C) → (B, T_out, 128, 128, C)
model = fx.bcat.base()

# PDEformer-2 — graph inputs → (n_graph, n_points, 1)
model = fx.pdeformer2.small()  # small/base/fast variants
model = fx.pdeformer2.base()
model = fx.pdeformer2.fast()

# DPOT — (B, 128, 128, T, C) → (B, 128, 128, T_out, C)
model = fx.dpot.Ti()    # Ti/S/M/L/H variants
model = fx.dpot.S()
model = fx.dpot.M()
model = fx.dpot.L()
model = fx.dpot.H()

# PROSE — various sequence-to-sequence configurations
model, variables = fx.prose.fd_1to1()                          # (B, T_in, 128, 128, C) → (B, T_out, 128, 128, C)
model, variables = fx.prose.fd_2to1(n_words=64)                # (B, T_in, 128, 128, C) + symbols → predictions
model, variables = fx.prose.ode_2to1(n_words=64, pad_index=0)  # (T, 1, C) + text → (T_out, C)
model, variables = fx.prose.pde_2to1(n_words=64, pad_index=0)  # (T, 1, C) + text → spatial field
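The models above document different axis orders. For example, MPP expects (T, B, C, H, W), Walrus expects (B, T, H, W, C), and DPOT expects (B, 128, 128, T, C). The pure-Python sketch below is a hypothetical helper, not part of foundax. It computes the permutation between two layouts written as strings of single-letter axis names, and the resulting tuple is in the form `jax.numpy.transpose(x, axes)` expects.

```python
from typing import Sequence, Tuple


# Hypothetical helper, not part of the foundax API.
def layout_perm(src: str, dst: str) -> Tuple[int, ...]:
    """Axes permutation that reorders an array laid out as `src` into `dst`.

    Each layout is a string of unique single-letter axis names, e.g. "TBCHW".
    The result satisfies: out.shape[i] == x.shape[perm[i]].
    """
    if sorted(src) != sorted(dst) or len(set(src)) != len(src):
        raise ValueError(f"incompatible layouts: {src!r} -> {dst!r}")
    return tuple(src.index(axis) for axis in dst)


def permute_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Shape an array would have after jnp.transpose(x, perm)."""
    return tuple(shape[p] for p in perm)


if __name__ == "__main__":
    # MPP layout (T, B, C, H, W) -> Walrus layout (B, T, H, W, C)
    perm = layout_perm("TBCHW", "BTHWC")
    print(perm)  # (1, 0, 3, 4, 2)
    print(permute_shape((10, 4, 3, 128, 128), perm))  # (4, 10, 128, 128, 3)
```

For instance, `jnp.transpose(x, layout_perm("TBCHW", "BTHWC"))` would reorder an MPP-style array into Walrus order. Reordering axes only aligns the layouts. It does not guarantee that the two models interpret channels, time steps, or resolution the same way.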

Integration with jNO

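import foundax as fx
import jno
import optax

# Wrap a foundax model as a jNO network
net = jno.nn.wrap(fx.mlp(in_features=2, output_dim=1))
# Configure an optax Adam optimizer with learning rate 1e-3
net.optimizer(optax.adam, lr=1e-3)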

License

This project is licensed under the MIT License.

Foundation models remain subject to their original licenses; see THIRD_PARTY_LICENSES for details. Note that some pretrained weights (e.g., Poseidon) are released under non-commercial licenses.



Download files


Source Distribution

foundax-0.1.3.tar.gz (156.4 kB)

Uploaded Source

Built Distribution


foundax-0.1.3-py3-none-any.whl (195.9 kB)

Uploaded Python 3

File details

Details for the file foundax-0.1.3.tar.gz.

File metadata

  • Download URL: foundax-0.1.3.tar.gz
  • Upload date:
  • Size: 156.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for foundax-0.1.3.tar.gz
Algorithm Hash digest
SHA256 d29a650d75dad2f980a75761db97cdca3a7ccc74275913ba7caca14684e17e04
MD5 3225707e783bb2efac56b6f5219923ea
BLAKE2b-256 fabed2f2d3f61cbee4092e5912ec72cb3755abe2f60cfe7f53477d93b44ccd2d


File details

Details for the file foundax-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: foundax-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 195.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for foundax-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 56aad4097e1f4d3cd3c966450022290354b2eaa383b63568cbd142b866a9ae02
MD5 938029522aa149ea4111111a2e8fecd6
BLAKE2b-256 a0c4ac92757d74eb921b45612a47d130282c8549d32dfb2d77c726bb07ee2651
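To check a downloaded file against the digests above, you can recompute them with Python's standard `hashlib`. BLAKE2b-256 corresponds to `hashlib.blake2b` with a 32-byte digest. This is a generic sketch. The helper names are hypothetical, and the `__main__` block assumes the wheel was saved in the current directory.

```python
import hashlib
from pathlib import Path
from typing import Dict


def file_digests(path: str, chunk_size: int = 1 << 20) -> Dict[str, str]:
    """Compute the SHA256, MD5 and BLAKE2b-256 hex digests of a file."""
    hashers = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        "blake2b_256": hashlib.blake2b(digest_size=32),
    }
    with Path(path).open("rb") as f:
        # Read in chunks so large files do not have to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify_sha256(path: str, expected: str) -> bool:
    """True if the file's SHA256 digest matches the expected hex string."""
    return file_digests(path)["sha256"] == expected.strip().lower()


if __name__ == "__main__":
    # Assumption: the wheel was downloaded into the current working directory.
    wheel = "foundax-0.1.3-py3-none-any.whl"
    expected = "56aad4097e1f4d3cd3c966450022290354b2eaa383b63568cbd142b866a9ae02"
    if Path(wheel).exists():
        print(verify_sha256(wheel, expected))
```

pip can enforce the same check at install time. To do this, pin `foundax==0.1.3 --hash=sha256:<digest>` in a requirements file and run pip with `--require-hashes`.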

