A unified deep learning model building library for PyTorch and JAX with shared dataclass-based configuration

ml-networks

A deep learning model building library supporting both PyTorch and JAX (Flax NNX). Build models across frameworks using a unified Config system.

Documentation | JAX Guide


Features

  • Dual Framework: PyTorch / JAX (Flax NNX) with a shared Config system
  • Vision Models: Encoder, Decoder, ConvNet, ResNet (PixelShuffle/Unshuffle), Vision Transformer
  • Generative Models: Conditional UNet (1D/2D) for Diffusion Models
  • Distributions: Normal, Categorical, Bernoulli, BSQ Codebook
  • Loss Functions: Focal Loss, Charbonnier Loss, Focal Frequency Loss, KL Divergence
  • Advanced: HyperNetwork, Contrastive Learning, SpatialSoftmax
  • Utilities: Custom activations, optimizers, blosc2 I/O, seed control

Installation

# PyTorch modules only
pip install ml-networks

# With JAX support (Python 3.11+ required)
pip install "ml-networks[jax]"

uv / rye

# uv
uv add ml-networks
uv add "ml-networks[jax]"  # with JAX

# rye
rye add ml-networks
rye add "ml-networks[jax]"  # with JAX

Development version (from GitHub)

pip install git+https://github.com/keio-crl/ml-networks.git

Requirements

            PyTorch backend    JAX backend
Python      >= 3.10            >= 3.11
Framework   PyTorch >= 2.0     JAX >= 0.4.30, Flax >= 0.10.0

Quick Start

MLP

import torch
from ml_networks.torch import MLPLayer
from ml_networks import MLPConfig, LinearConfig

cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg)
y = mlp(torch.randn(32, 16))  # (32, 8)

Encoder

import torch
from ml_networks.torch import Encoder
from ml_networks import ConvNetConfig, ConvConfig, LinearConfig

backbone_cfg = ConvNetConfig(
    channels=[16, 32, 64],
    conv_cfgs=[
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
    ],
)
fc_cfg = LinearConfig(activation="ReLU", bias=True)

encoder = Encoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
z = encoder(torch.randn(32, 3, 64, 64))  # (32, 64)
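
The spatial size that the ConvNet backbone above hands to the FC layer can be worked out by hand. The sketch below is not the library's own shape helper; `conv_out_size` and `feature_map_shape` are hypothetical names, and it assumes the standard convolution output formula with dilation 1 and the same kernel, stride, and padding for every layer.

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length of one conv layer along a single spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def feature_map_shape(obs_shape, channels, kernel_size=3, stride=2, padding=1):
    """Apply one conv per entry in `channels` and return the resulting (C, H, W)."""
    _, height, width = obs_shape
    for _ in channels:
        height = conv_out_size(height, kernel_size, stride, padding)
        width = conv_out_size(width, kernel_size, stride, padding)
    return (channels[-1], height, width)


# Three stride-2 convs halve 64 -> 32 -> 16 -> 8.
print(feature_map_shape((3, 64, 64), [16, 32, 64]))  # (64, 8, 8)
```

Under these assumptions the FC layer sees a 64 x 8 x 8 feature map before it is reduced to feature_dim.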

Backbone options

from ml_networks import ResNetConfig, ViTConfig, TransformerConfig

# ResNet + PixelUnshuffle
backbone_cfg = ResNetConfig(
    conv_channel=64,
    conv_kernel=3,
    f_kernel=3,
    conv_activation="ReLU",
    out_activation="ReLU",
    n_res_blocks=3,
    scale_factor=2,
    n_scaling=3,
    norm="batch",
    norm_cfg={"affine": True},
)

# Vision Transformer
backbone_cfg = ViTConfig(
    patch_size=8,
    transformer_cfg=TransformerConfig(
        d_model=64,
        nhead=8,
        dim_ff=256,
        n_layers=3,
    ),
)
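
For the Vision Transformer settings above, the token layout follows from simple arithmetic. This is a sketch under stated assumptions: it uses the standard non-overlapping ViT patching scheme and the (3, 64, 64) observation shape from the Encoder example, and `vit_token_layout` is a hypothetical helper, not part of ml-networks.

```python
def vit_token_layout(obs_shape, patch_size, d_model, nhead):
    """Return patch count, flattened patch size, and per-head width."""
    channels, height, width = obs_shape
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    if d_model % nhead:
        raise ValueError("d_model must be divisible by nhead")
    n_patches = (height // patch_size) * (width // patch_size)
    patch_dim = channels * patch_size * patch_size
    return {"n_patches": n_patches, "patch_dim": patch_dim, "head_dim": d_model // nhead}


layout = vit_token_layout((3, 64, 64), patch_size=8, d_model=64, nhead=8)
print(layout)  # {'n_patches': 64, 'patch_dim': 192, 'head_dim': 8}
```

The two checks mirror the usual constraints: the image must tile evenly into patches, and d_model must split evenly across attention heads.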

FC layer options

from ml_networks import MLPConfig, LinearConfig, SpatialSoftmaxConfig, AdaptiveAveragePoolingConfig

# Pick one of the following for fc_cfg.

# MLP head
fc_cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)
# Single linear layer
fc_cfg = LinearConfig(activation="ReLU", bias=True)
# Spatial softmax
fc_cfg = SpatialSoftmaxConfig(temperature=1.0)
# Adaptive average pooling
fc_cfg = AdaptiveAveragePoolingConfig()
# No head: output the feature map directly
fc_cfg = None

Decoder

import torch
from ml_networks.torch import Decoder
from ml_networks import ConvNetConfig, ConvConfig, MLPConfig, LinearConfig

backbone_cfg = ConvNetConfig(
    channels=[64, 32, 16],
    conv_cfgs=[
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="Tanh"),
    ],
)
fc_cfg = MLPConfig(
    hidden_dim=256,
    n_layers=2,
    output_activation="ReLU",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

decoder = Decoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
img = decoder(torch.randn(32, 64))  # (32, 3, 64, 64)
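
The kernel_size=4, stride=2, padding=1 choice in the Decoder example is the usual setting for exact 2x upsampling. The sketch below assumes the backbone upsamples with transposed convolutions and uses the PyTorch ConvTranspose2d output formula with dilation 1 and output_padding 0; whether ml-networks does exactly this is an assumption, and both function names are hypothetical.

```python
def conv_transpose_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """ConvTranspose2d output length along one axis (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel_size


def upsample_sizes(start: int, n_layers: int, kernel_size=4, stride=2, padding=1):
    """Spatial size after each of n_layers transposed convs, including the start."""
    sizes = [start]
    for _ in range(n_layers):
        sizes.append(conv_transpose_out_size(sizes[-1], kernel_size, stride, padding))
    return sizes


# Three layers take an 8x8 map to the 64x64 target resolution.
print(upsample_sizes(8, 3))  # [8, 16, 32, 64]
```

Under these assumptions, a three-layer decoder targeting 64 x 64 output starts from an 8 x 8 feature map.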

Conditional UNet

import torch
from ml_networks.torch import ConditionalUnet2d, ConditionalUnet1d
from ml_networks import UNetConfig, ConvConfig

cfg = UNetConfig(
    channels=[64, 128, 256],
    conv_cfg=ConvConfig(kernel_size=3, padding=1, stride=1, activation="ReLU"),
    has_attn=True,
    nhead=8,
    cond_pred_scale=True,
)

# 2D (images)
net = ConditionalUnet2d(feature_dim=32, obs_shape=(3, 64, 64), cfg=cfg)
out = net(torch.randn(2, 3, 64, 64), cond=torch.randn(2, 32))  # (2, 3, 64, 64)

# 1D (sequences)
net = ConditionalUnet1d(feature_dim=32, obs_shape=(8, 128), cfg=cfg)
out = net(torch.randn(2, 8, 128), cond=torch.randn(2, 32))  # (2, 8, 128)

Distributions

import torch
from ml_networks.torch import Distribution, Encoder, stack_dist, cat_dist

# backbone_cfg and fc_cfg are defined as in the Encoder example above.
dist = Distribution(in_dim=64, dist="normal")
encoder = Encoder(feature_dim=128, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)

obs = torch.randn(32, 3, 64, 64)
z = encoder(obs)  # (B, 128): mean & std concatenated (64 each)
dist_z = dist(z)  # NormalStoch(mean, std, stoch)

# Convert to torch.distributions for KLD computation
torch_dist = dist_z.get_distribution(independent=1)

# Stack / concatenate a list of distribution objects
dist_list = [dist_z, dist_z]
stacked = stack_dist(dist_list, dim=0)
catted = cat_dist(dist_list, dim=-1)

# Save to disk (blosc2 format)
dist_z.save("reports/")
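
For reference, the KL divergence between two diagonal Gaussians has a closed form. The pure-Python sketch below illustrates the formula only; it is not the library's implementation, the function names are hypothetical, and summing over the last axis is an assumption about what independent=1 implies.

```python
import math


def kl_normal(mean_p: float, std_p: float, mean_q: float, std_q: float) -> float:
    """KL(N(mean_p, std_p^2) || N(mean_q, std_q^2)) for scalar Gaussians."""
    var_ratio = (std_p / std_q) ** 2
    mean_term = ((mean_p - mean_q) / std_q) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - math.log(var_ratio))


def kl_diag_normal(means_p, stds_p, means_q, stds_q) -> float:
    """Sum the per-dimension KL terms, treating the whole vector as one event."""
    return sum(kl_normal(*args) for args in zip(means_p, stds_p, means_q, stds_q))


# Only the first dimension differs (mean 0 vs 1, unit std), so the KL is 0.5.
print(kl_diag_normal([0.0, 0.0], [1.0, 1.0], [1.0, 0.0], [1.0, 1.0]))  # 0.5
```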

Loss Functions

from ml_networks.torch import focal_loss, binary_focal_loss, FocalFrequencyLoss, charbonnier

# Focal loss (classification)
loss = focal_loss(logits, labels, gamma=2.0)

# Charbonnier loss (image reconstruction)
loss = charbonnier(pred, target, epsilon=1e-3)

# Focal frequency loss (frequency-domain reconstruction)
ffl = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0)
loss = ffl(pred, target)
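
To make the first two losses concrete, here is a minimal pure-Python sketch of their textbook definitions. It is an illustration, not the library's code: the reduction (mean), the epsilon-squared convention for Charbonnier, and the hypothetical names `charbonnier_loss` and `focal_loss_single` are all assumptions.

```python
import math


def charbonnier_loss(pred, target, epsilon: float = 1e-3) -> float:
    """Mean of sqrt((pred - target)^2 + epsilon^2) over two flat sequences."""
    terms = [math.sqrt((p - t) ** 2 + epsilon ** 2) for p, t in zip(pred, target)]
    return sum(terms) / len(terms)


def focal_loss_single(logits, label: int, gamma: float = 2.0) -> float:
    """Focal loss for one sample: -(1 - p_t)^gamma * log(p_t)."""
    shift = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(z - shift) for z in logits]
    p_t = exps[label] / sum(exps)
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


print(charbonnier_loss([0.0, 1.0], [0.0, 3.0]))
print(focal_loss_single([2.0, 0.5, -1.0], label=0))
```

With gamma=0 the focal term reduces to ordinary cross-entropy; larger gamma down-weights samples the model already classifies confidently.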

JAX Backend

Switch frameworks without changing your configs. JAX modules require an rngs argument (an nnx.Rngs instance) at initialization.

from ml_networks.jax import MLPLayer, Encoder, Decoder
from ml_networks import MLPConfig, LinearConfig, ConvNetConfig, ConvConfig
from flax import nnx
import jax.numpy as jnp

cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

rngs = nnx.Rngs(0)
mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg, rngs=rngs)
y = mlp(jnp.ones((32, 16)))  # (32, 8)

JAX modules use NHWC format (channels-last). See the JAX Guide for details.
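
When reusing inputs shaped for the PyTorch examples (channels-first, NCHW), the axes need to be permuted for the JAX modules. The sketch below only shows the axis bookkeeping on shape tuples; `permute_shape` is a hypothetical helper, and in practice the same axis tuple would be passed to something like jnp.transpose.

```python
NCHW_TO_NHWC = (0, 2, 3, 1)  # axis order to pass to e.g. jnp.transpose(x, NCHW_TO_NHWC)
NHWC_TO_NCHW = (0, 3, 1, 2)  # inverse permutation


def permute_shape(shape, axes):
    """Return the shape an array would have after transposing with `axes`."""
    return tuple(shape[axis] for axis in axes)


torch_style = (32, 3, 64, 64)  # (N, C, H, W)
jax_style = permute_shape(torch_style, NCHW_TO_NHWC)
print(jax_style)  # (32, 64, 64, 3)
```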

Data I/O (blosc2)

from ml_networks import save_blosc2, load_blosc2

save_blosc2(data, "dataset/image.blosc2")
loaded = load_blosc2("dataset/image.blosc2")

Utilities

from ml_networks.torch import Activation, get_optimizer, torch_fix_seed
from ml_networks import determine_loader

# Custom activations: REReLU, SiGLU, CRReLU, TanhExp, L2Norm
act = Activation("REReLU")

# Optimizer (supports pytorch-optimizer library)
optimizer = get_optimizer(model.parameters(), "Adam", lr=1e-3)

# Reproducibility
torch_fix_seed(42)
loader = determine_loader(dataset, seed=42, batch_size=32, shuffle=True)

Package Structure

ml_networks/
├── config.py          # Shared config classes (PyTorch/JAX)
├── utils.py           # Shared utilities (blosc2 I/O, conv shape calc)
├── callbacks.py       # PyTorch Lightning callbacks
├── torch/             # PyTorch implementation
│   ├── layers.py      # MLP, Conv, Attention, Transformer
│   ├── vision.py      # Encoder, Decoder, ConvNet, ResNet, ViT
│   ├── unet.py        # ConditionalUnet1d/2d
│   ├── distributions.py
│   ├── loss.py
│   ├── activations.py
│   ├── hypernetworks.py
│   ├── contrastive.py
│   └── torch_utils.py
└── jax/               # JAX (Flax NNX) implementation
    ├── layers.py
    ├── vision.py
    ├── unet.py
    ├── distributions.py
    ├── loss.py
    ├── activations.py
    ├── hypernetworks.py
    ├── contrastive.py
    └── jax_utils.py

Development

git clone https://github.com/keio-crl/ml-networks.git
cd ml-networks
pip install -e ".[dev]"

# Quality checks
ruff check .          # Lint
ruff format .         # Format
mypy src/             # Type check
pre-commit run --all-files  # All checks

Versioning

This project follows Semantic Versioning (MAJOR.MINOR.PATCH).

python scripts/bump_version.py patch   # 0.1.0 -> 0.1.1
python scripts/bump_version.py minor   # 0.1.0 -> 0.2.0
python scripts/bump_version.py major   # 0.1.0 -> 1.0.0

Or use the Version Bump workflow on GitHub Actions.
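
The bump rules shown in the comments above can be sketched in a few lines. This is not the contents of scripts/bump_version.py; the function below is a hypothetical illustration that handles plain MAJOR.MINOR.PATCH strings only (no pre-release or build suffixes).

```python
def bump_version(version: str, part: str) -> str:
    """Increment one component and reset the lower-order components to zero."""
    major, minor, patch = (int(piece) for piece in version.split("."))
    if part == "major":
        return f"{major + 1}.0.0"
    if part == "minor":
        return f"{major}.{minor + 1}.0"
    if part == "patch":
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown part: {part!r}")


print(bump_version("0.1.0", "patch"))  # 0.1.1
print(bump_version("0.1.0", "minor"))  # 0.2.0
print(bump_version("0.1.0", "major"))  # 1.0.0
```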

CI/CD

Workflow   Trigger                      Checks
CI         Push / PR to main, develop   ruff, mypy, pytest, build
Release    Tag push (vX.Y.Z)            Build + GitHub Release
Docs       Push to main                 Deploy documentation
