A unified deep learning model building library for PyTorch and JAX with shared dataclass-based configuration
ml-networks
A deep learning model building library supporting both PyTorch and JAX (Flax NNX). Build models across frameworks using a unified Config system.
Features
- Dual Framework: PyTorch / JAX (Flax NNX) with a shared Config system
- Vision Models: Encoder, Decoder, ConvNet, ResNet (PixelShuffle/Unshuffle), Vision Transformer
- Generative Models: Conditional UNet (1D/2D) for Diffusion Models
- Distributions: Normal, Categorical, Bernoulli, BSQ Codebook
- Loss Functions: Focal Loss, Charbonnier Loss, Focal Frequency Loss, KL Divergence
- Advanced: HyperNetwork, Contrastive Learning, SpatialSoftmax
- Utilities: Custom activations, optimizers, blosc2 I/O, seed control
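The shared Config idea can be illustrated with plain dataclasses: one framework-agnostic description of a network that any backend could read. This is a minimal sketch, not the library's implementation. `ToyMLPConfig` and `layer_sizes` are hypothetical names, and treating `n_layers` as the number of hidden layers is an assumption.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ToyMLPConfig:
    """Hypothetical framework-agnostic config (not the library's MLPConfig)."""

    hidden_dim: int = 128
    n_layers: int = 2  # assumed here to mean the number of hidden layers
    activation: str = "ReLU"


def layer_sizes(input_dim: int, output_dim: int, cfg: ToyMLPConfig) -> list[tuple[int, int]]:
    """Return the (in, out) size of every linear layer described by cfg."""
    dims = [input_dim] + [cfg.hidden_dim] * cfg.n_layers + [output_dim]
    return list(zip(dims[:-1], dims[1:]))


cfg = ToyMLPConfig()
# Both backends would read the same plan; only layer construction would differ.
torch_plan = layer_sizes(16, 8, cfg)
jax_plan = layer_sizes(16, 8, cfg)
print(torch_plan)  # [(16, 128), (128, 128), (128, 8)]
print(asdict(cfg))
```

Because the config holds only plain Python values, it can be serialized, compared, and reused without importing either framework.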
Installation
# PyTorch modules only
pip install ml-networks
# With JAX support (Python 3.11+ required)
pip install "ml-networks[jax]"
uv / rye
# uv
uv add ml-networks
uv add "ml-networks[jax]" # with JAX
# rye
rye add ml-networks
rye add "ml-networks[jax]" # with JAX
Development version (from GitHub)
pip install git+https://github.com/keio-crl/ml-networks.git
Requirements
| | PyTorch backend | JAX backend |
|---|---|---|
| Python | >= 3.10 | >= 3.11 |
| Framework | PyTorch >= 2.0 | JAX >= 0.4.30, Flax >= 0.10.0 |
Quick Start
MLP
import torch
from ml_networks.torch import MLPLayer
from ml_networks import MLPConfig, LinearConfig
cfg = MLPConfig(
hidden_dim=128,
n_layers=2,
output_activation="Tanh",
linear_cfg=LinearConfig(activation="ReLU", bias=True),
)
mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg)
y = mlp(torch.randn(32, 16)) # (32, 8)
Encoder
import torch
from ml_networks.torch import Encoder
from ml_networks import ConvNetConfig, ConvConfig, LinearConfig
backbone_cfg = ConvNetConfig(
channels=[16, 32, 64],
conv_cfgs=[
ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
],
)
fc_cfg = LinearConfig(activation="ReLU", bias=True)
encoder = Encoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
z = encoder(torch.randn(32, 3, 64, 64)) # (32, 64)
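To see how the three stride-2 convolutions shrink a 64x64 input, the standard convolution output-size formula can be applied stage by stage. This is a sketch under assumptions: `conv_out_size` is a hypothetical helper rather than the library's shape utility, dilation is taken as 1, and the last entry of `channels` is taken as the final channel count.

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Standard convolution output size along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 64
sizes = [size]
for _ in range(3):  # three stages with kernel_size=3, stride=2, padding=1
    size = conv_out_size(size, kernel_size=3, stride=2, padding=1)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8]

# Assumed: the final feature map has 64 channels (last entry of `channels`).
flat_features = 64 * size * size
print(flat_features)  # 4096
```

Under these assumptions the fully connected head would receive 4096 flattened features before projecting to `feature_dim`.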
Backbone options
from ml_networks import ResNetConfig, ViTConfig, TransformerConfig
# ResNet + PixelUnshuffle
backbone_cfg = ResNetConfig(
conv_channel=64,
conv_kernel=3,
f_kernel=3,
conv_activation="ReLU",
out_activation="ReLU",
n_res_blocks=3,
scale_factor=2,
n_scaling=3,
norm="batch",
norm_cfg={"affine": True},
)
# Vision Transformer
backbone_cfg = ViTConfig(
patch_size=8,
transformer_cfg=TransformerConfig(
d_model=64,
nhead=8,
dim_ff=256,
n_layers=3,
),
)
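The shape arithmetic behind these two backbones can be checked by hand. The sketch below uses hypothetical helpers (`vit_token_count`, `pixel_unshuffle_shape`) and covers only the generic operations: splitting an image into non-overlapping patches, and a single pixel-unshuffle step. How ml-networks combines these with its own layers (for example a class token, or convolutions that reset the channel count) is not shown.

```python
def vit_token_count(height: int, width: int, patch_size: int) -> int:
    """Number of non-overlapping patches for an image of the given size."""
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")
    return (height // patch_size) * (width // patch_size)


def pixel_unshuffle_shape(c: int, h: int, w: int, r: int) -> tuple[int, int, int]:
    """Shape after one pixel-unshuffle with factor r: (C * r * r, H / r, W / r)."""
    if h % r or w % r:
        raise ValueError("spatial sides must be divisible by the scale factor")
    return (c * r * r, h // r, w // r)


print(vit_token_count(64, 64, patch_size=8))  # 64
print(64 // 8)  # per-head width for d_model=64, nhead=8 -> 8
print(pixel_unshuffle_shape(64, 64, 64, r=2))  # (256, 32, 32)
```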
FC layer options
from ml_networks import MLPConfig, LinearConfig, SpatialSoftmaxConfig, AdaptiveAveragePoolingConfig
fc_cfg = MLPConfig(
hidden_dim=128,
n_layers=2,
output_activation="Tanh",
linear_cfg=LinearConfig(activation="ReLU", bias=True),
)
fc_cfg = LinearConfig(activation="ReLU", bias=True)
fc_cfg = SpatialSoftmaxConfig(temperature=1.0)
fc_cfg = AdaptiveAveragePoolingConfig()
fc_cfg = None # output feature map directly
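Spatial softmax turns each channel of a feature map into an expected 2D coordinate. The sketch below shows the general technique in pure Python for a single channel. It is not the library's implementation: the function name, the coordinate range of [-1, 1], and dividing the activations by `temperature` are all assumptions.

```python
import math


def spatial_softmax(feature_map: list[list[float]], temperature: float = 1.0) -> tuple[float, float]:
    """Return the expected (x, y) position of one channel, coordinates in [-1, 1]."""
    h, w = len(feature_map), len(feature_map[0])
    flat = [v / temperature for row in feature_map for v in row]
    peak = max(flat)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in flat]
    total = sum(exps)
    probs = [e / total for e in exps]
    xs = [-1.0 + 2.0 * j / (w - 1) for j in range(w)]
    ys = [-1.0 + 2.0 * i / (h - 1) for i in range(h)]
    ex = sum(probs[i * w + j] * xs[j] for i in range(h) for j in range(w))
    ey = sum(probs[i * w + j] * ys[i] for i in range(h) for j in range(w))
    return ex, ey


# A strong activation in the top-right corner pulls the keypoint towards (1, -1).
peaked = [[0.0, 0.0, 50.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(spatial_softmax(peaked))
```

With this formulation a C-channel feature map yields 2C output values, one (x, y) pair per channel.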
Decoder
import torch
from ml_networks.torch import Decoder
from ml_networks import ConvNetConfig, ConvConfig, MLPConfig, LinearConfig
backbone_cfg = ConvNetConfig(
channels=[64, 32, 16],
conv_cfgs=[
ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
ConvConfig(kernel_size=4, stride=2, padding=1, activation="Tanh"),
],
)
fc_cfg = MLPConfig(
hidden_dim=256,
n_layers=2,
output_activation="ReLU",
linear_cfg=LinearConfig(activation="ReLU", bias=True),
)
decoder = Decoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
img = decoder(torch.randn(32, 64)) # (32, 3, 64, 64)
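The choice of kernel_size=4, stride=2, padding=1 doubles the resolution at each stage if the decoder upsamples with transposed convolutions, which is an assumption here. The sketch applies the standard transposed-convolution size formula. `conv_transpose_out_size` is a hypothetical helper, and the 8x8 starting resolution is assumed so that three doublings reach 64x64.

```python
def conv_transpose_out_size(
    size: int, kernel_size: int, stride: int, padding: int, output_padding: int = 0
) -> int:
    """Standard transposed-convolution output size along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


size = 8  # assumed starting resolution
sizes = [size]
for _ in range(3):  # three stages with kernel_size=4, stride=2, padding=1
    size = conv_transpose_out_size(size, kernel_size=4, stride=2, padding=1)
    sizes.append(size)

print(sizes)  # [8, 16, 32, 64]
```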
Conditional UNet
import torch
from ml_networks.torch import ConditionalUnet2d, ConditionalUnet1d
from ml_networks import UNetConfig, ConvConfig
cfg = UNetConfig(
channels=[64, 128, 256],
conv_cfg=ConvConfig(kernel_size=3, padding=1, stride=1, activation="ReLU"),
has_attn=True,
nhead=8,
cond_pred_scale=True,
)
# 2D (images)
net = ConditionalUnet2d(feature_dim=32, obs_shape=(3, 64, 64), cfg=cfg)
out = net(torch.randn(2, 3, 64, 64), cond=torch.randn(2, 32)) # (2, 3, 64, 64)
# 1D (sequences)
net = ConditionalUnet1d(feature_dim=32, obs_shape=(8, 128), cfg=cfg)
out = net(torch.randn(2, 8, 128), cond=torch.randn(2, 32)) # (2, 8, 128)
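For context on the diffusion use case, the following sketch shows the standard DDPM forward (noising) step that produces the inputs a denoising network is trained on. It is generic and not taken from ml-networks: the linear beta schedule, its default values, and the helper names are assumptions, and how a timestep would be passed to the U-Net is not covered.

```python
import math
import random


def linear_alpha_bars(num_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Cumulative products of (1 - beta_t) for a linear beta schedule (num_steps >= 2)."""
    alpha_bars, prod = [], 1.0
    for t in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * t / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def add_noise(x0: list[float], noise: list[float], alpha_bar: float) -> list[float]:
    """DDPM forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise."""
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + b * n for x, n in zip(x0, noise)]


rng = random.Random(0)
alpha_bars = linear_alpha_bars(1000)
x0 = [0.5, -0.25, 1.0]
noise = [rng.gauss(0.0, 1.0) for _ in x0]
x_t = add_noise(x0, noise, alpha_bars[499])
# A denoiser would be trained to recover `noise` from x_t and the conditioning.
print(x_t)
```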
Distributions
import torch
from ml_networks.torch import Distribution, Encoder, stack_dist, cat_dist
dist = Distribution(in_dim=64, dist="normal")
# backbone_cfg and fc_cfg are the configs defined in the Encoder section
encoder = Encoder(feature_dim=128, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
obs = torch.randn(32, 3, 64, 64)
z = encoder(obs) # (32, 128): mean and std concatenated
dist_z = dist(z) # NormalStoch(mean, std, stoch)
# Convert to torch.distributions for KLD computation
torch_dist = dist_z.get_distribution(independent=1)
# Stack / concatenate a list of distribution objects
dist_list = [dist_z, dist_z]
stacked = stack_dist(dist_list, dim=0)
catted = cat_dist(dist_list, dim=-1)
# Save to disk (blosc2 format)
dist_z.save("reports/")
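The "mean and std concatenated" layout and the KL divergence step can be sketched in pure Python. The example is illustrative and makes assumptions: the first half of the vector is taken as the mean, softplus is used to keep the standard deviation positive, and all function names are hypothetical. The library's own parameterization may differ.

```python
import math
import random


def split_mean_std(z: list[float]) -> tuple[list[float], list[float]]:
    """Split a vector into mean (first half) and std (softplus of second half)."""
    half = len(z) // 2
    mean = z[:half]
    std = [math.log1p(math.exp(v)) + 1e-4 for v in z[half:]]  # softplus keeps std > 0
    return mean, std


def rsample(mean: list[float], std: list[float], rng: random.Random) -> list[float]:
    """Reparameterized sample: mean + std * standard normal noise."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]


def kl_diag_normal(mean_p, std_p, mean_q, std_q) -> float:
    """KL(p || q) between diagonal Gaussians, summed over dimensions."""
    return sum(
        math.log(sq / sp) + (sp**2 + (mp - mq) ** 2) / (2 * sq**2) - 0.5
        for mp, sp, mq, sq in zip(mean_p, std_p, mean_q, std_q)
    )


mean, std = split_mean_std([0.0, 1.0, 0.0, 0.0])
sample = rsample(mean, std, random.Random(0))
print(kl_diag_normal([0.0], [1.0], [1.0], [1.0]))  # 0.5
```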
Loss Functions
from ml_networks.torch import focal_loss, binary_focal_loss, FocalFrequencyLoss, charbonnier
# Focal loss (classification)
loss = focal_loss(logits, labels, gamma=2.0)
# Charbonnier loss (image reconstruction)
loss = charbonnier(pred, target, epsilon=1e-3)
# Focal frequency loss (frequency-domain reconstruction)
ffl = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0)
loss = ffl(pred, target)
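For reference, the textbook forms of two of these losses are short enough to write out. This is a pure-Python sketch of the common definitions, not the library's code: the mean reduction, the `epsilon ** 2` variant of the Charbonnier loss, and the function names are assumptions.

```python
import math


def charbonnier_loss(pred: list[float], target: list[float], epsilon: float = 1e-3) -> float:
    """Mean of sqrt((pred - target)^2 + epsilon^2), a smooth variant of L1."""
    terms = [math.sqrt((p - t) ** 2 + epsilon**2) for p, t in zip(pred, target)]
    return sum(terms) / len(terms)


def focal_loss_single(logits: list[float], label: int, gamma: float = 2.0) -> float:
    """Focal loss for one sample: -(1 - p_t)^gamma * log(p_t)."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(v - peak) for v in logits))
    log_pt = logits[label] - log_z  # log-probability of the true class
    pt = math.exp(log_pt)
    return -(max(0.0, 1.0 - pt) ** gamma) * log_pt


print(charbonnier_loss([1.0, 2.0], [1.0, 2.5]))
print(focal_loss_single([2.0, 0.5, -1.0], label=0))
```

With `gamma=0` the focal loss reduces to ordinary cross-entropy; larger values of `gamma` down-weight samples the model already classifies confidently.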
JAX Backend
Switch frameworks without changing your configs. JAX modules require rngs at initialization.
from ml_networks.jax import MLPLayer, Encoder, Decoder
from ml_networks import MLPConfig, LinearConfig, ConvNetConfig, ConvConfig
from flax import nnx
import jax.numpy as jnp
cfg = MLPConfig(
hidden_dim=128,
n_layers=2,
output_activation="Tanh",
linear_cfg=LinearConfig(activation="ReLU", bias=True),
)
rngs = nnx.Rngs(0)
mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg, rngs=rngs)
y = mlp(jnp.ones((32, 16))) # (32, 8)
JAX modules use NHWC format (channels-last). See the JAX Guide for details.
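The layout difference matters when moving data between backends: an NCHW batch must be transposed to NHWC before it reaches a JAX module. The sketch below shows the axis permutation on shape tuples and on nested lists. The helper names are hypothetical. With arrays, the same permutation is usually written as `jnp.transpose(x, (0, 2, 3, 1))`.

```python
def nchw_to_nhwc_shape(shape: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Permute a (N, C, H, W) shape into (N, H, W, C)."""
    n, c, h, w = shape
    return (n, h, w, c)


def nchw_to_nhwc(batch: list) -> list:
    """Permute nested lists laid out as (N, C, H, W) into (N, H, W, C)."""
    return [
        [
            [[sample[c][h][w] for c in range(len(sample))] for w in range(len(sample[0][0]))]
            for h in range(len(sample[0]))
        ]
        for sample in batch
    ]


print(nchw_to_nhwc_shape((32, 3, 64, 64)))  # (32, 64, 64, 3)

# One sample, three channels, a 1x2 image.
batch = [[[[1, 2]], [[3, 4]], [[5, 6]]]]
print(nchw_to_nhwc(batch))  # [[[[1, 3, 5], [2, 4, 6]]]]
```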
Data I/O (blosc2)
from ml_networks import save_blosc2, load_blosc2
save_blosc2(data, "dataset/image.blosc2")
loaded = load_blosc2("dataset/image.blosc2")
Utilities
from ml_networks.torch import Activation, get_optimizer, torch_fix_seed
from ml_networks import determine_loader
# Custom activations: REReLU, SiGLU, CRReLU, TanhExp, L2Norm
act = Activation("REReLU")
# Optimizer (supports pytorch-optimizer library)
optimizer = get_optimizer(model.parameters(), "Adam", lr=1e-3)
# Reproducibility
torch_fix_seed(42)
loader = determine_loader(dataset, seed=42, batch_size=32, shuffle=True)
Package Structure
ml_networks/
├── config.py # Shared config classes (PyTorch/JAX)
├── utils.py # Shared utilities (blosc2 I/O, conv shape calc)
├── callbacks.py # PyTorch Lightning callbacks
├── torch/ # PyTorch implementation
│ ├── layers.py # MLP, Conv, Attention, Transformer
│ ├── vision.py # Encoder, Decoder, ConvNet, ResNet, ViT
│ ├── unet.py # ConditionalUnet1d/2d
│ ├── distributions.py
│ ├── loss.py
│ ├── activations.py
│ ├── hypernetworks.py
│ ├── contrastive.py
│ └── torch_utils.py
└── jax/ # JAX (Flax NNX) implementation
├── layers.py
├── vision.py
├── unet.py
├── distributions.py
├── loss.py
├── activations.py
├── hypernetworks.py
├── contrastive.py
└── jax_utils.py
Development
git clone https://github.com/keio-crl/ml-networks.git
cd ml-networks
pip install -e ".[dev]"
# Quality checks
ruff check . # Lint
ruff format . # Format
mypy src/ # Type check
pre-commit run --all-files # All checks
Versioning
Semantic Versioning (MAJOR.MINOR.PATCH).
python scripts/bump_version.py patch # 0.1.0 -> 0.1.1
python scripts/bump_version.py minor # 0.1.0 -> 0.2.0
python scripts/bump_version.py major # 0.1.0 -> 1.0.0
Or use the Version Bump workflow on GitHub Actions.
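The bump rules above follow Semantic Versioning: incrementing one component resets every component to its right. A minimal sketch of that rule follows. It is not the content of `scripts/bump_version.py`, and `bump_version` is a hypothetical name.

```python
def bump_version(version: str, part: str) -> str:
    """Return the next MAJOR.MINOR.PATCH version for the given part."""
    major, minor, patch = (int(piece) for piece in version.split("."))
    if part == "major":
        return f"{major + 1}.0.0"
    if part == "minor":
        return f"{major}.{minor + 1}.0"
    if part == "patch":
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown part: {part!r}")


print(bump_version("0.1.0", "patch"))  # 0.1.1
print(bump_version("0.1.0", "minor"))  # 0.2.0
print(bump_version("0.1.0", "major"))  # 1.0.0
```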
CI/CD
| Workflow | Trigger | Checks |
|---|---|---|
| CI | Push / PR to main, develop | ruff, mypy, pytest, build |
| Release | Tag push (vX.Y.Z) | Build + GitHub Release |
| Docs | Push to main | Deploy documentation |
Authors
- oakwood-fujiken (oakwood.n14.4sp@keio.jp)
- nomutin (nomura0508@icloud.com)