
A unified deep learning model building library for PyTorch and JAX with shared dataclass-based configuration

Project description

ml-networks



A deep learning model building library supporting both PyTorch and JAX (Flax NNX). Build models across frameworks using a unified Config system.

Documentation | JAX Guide


Features

  • Dual Framework: PyTorch / JAX (Flax NNX) with a shared Config system
  • Vision Models: Encoder, Decoder, ConvNet, ResNet (PixelShuffle/Unshuffle), Vision Transformer
  • Generative Models: Conditional UNet (1D/2D) for Diffusion Models
  • Distributions: Normal, Categorical, Bernoulli, BSQ Codebook
  • Loss Functions: Focal Loss, Charbonnier Loss, Focal Frequency Loss, KL Divergence
  • Advanced: HyperNetwork, Contrastive Learning, SpatialSoftmax
  • Utilities: Custom activations, optimizers, blosc2 I/O, seed control

Installation

# PyTorch modules only
pip install ml-networks

# With JAX support (Python 3.11+ required)
pip install "ml-networks[jax]"
uv / rye
# uv
uv add ml-networks
uv add "ml-networks[jax]"  # with JAX

# rye
rye add ml-networks
rye add "ml-networks[jax]"  # with JAX
Development version (from GitHub)
pip install git+https://github.com/keio-crl/ml-networks.git

Requirements

|           | PyTorch backend | JAX backend                   |
|-----------|-----------------|-------------------------------|
| Python    | >= 3.10         | >= 3.11                       |
| Framework | PyTorch >= 2.0  | JAX >= 0.4.30, Flax >= 0.10.0 |

Quick Start

MLP

import torch
from ml_networks.torch import MLPLayer
from ml_networks import MLPConfig, LinearConfig

cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg)
y = mlp(torch.randn(32, 16))  # (32, 8)

Encoder

import torch
from ml_networks.torch import Encoder
from ml_networks import ConvNetConfig, ConvConfig, LinearConfig

backbone_cfg = ConvNetConfig(
    channels=[16, 32, 64],
    conv_cfgs=[
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
    ],
)
fc_cfg = LinearConfig(activation="ReLU", bias=True)

encoder = Encoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
z = encoder(torch.randn(32, 3, 64, 64))  # (32, 64)
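With kernel 3, stride 2 and padding 1, each convolution above halves the spatial size. Assuming each ConvConfig corresponds to one standard convolution with no extra pooling, the 64x64 input therefore reaches the FC layer as an 8x8 map with 64 channels. The plain-Python sketch below checks this with PyTorch's Conv2d output-size formula; conv_out_size is a hypothetical helper, not necessarily the conv-shape utility in ml_networks/utils.py.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a Conv2d along one axis (PyTorch's floor formula)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# The three ConvConfig entries of the Encoder example: kernel 3, stride 2, padding 1
layers = [(3, 2, 1)] * 3
size = 64
sizes = [size]
for kernel_size, stride, padding in layers:
    size = conv_out_size(size, kernel_size, stride, padding)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8]
```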
Backbone options
from ml_networks import ResNetConfig, ViTConfig, TransformerConfig

# ResNet + PixelUnshuffle
backbone_cfg = ResNetConfig(
    conv_channel=64,
    conv_kernel=3,
    f_kernel=3,
    conv_activation="ReLU",
    out_activation="ReLU",
    n_res_blocks=3,
    scale_factor=2,
    n_scaling=3,
    norm="batch",
    norm_cfg={"affine": True},
)

# Vision Transformer
backbone_cfg = ViTConfig(
    patch_size=8,
    transformer_cfg=TransformerConfig(
        d_model=64,
        nhead=8,
        dim_ff=256,
        n_layers=3,
    ),
)
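ResNetConfig downsamples with PixelUnshuffle (scale_factor=2, n_scaling=3 above), which moves each r x r block of pixels into the channel axis instead of discarding information. The NumPy sketch below shows the bare operation, using the same channel ordering as torch.nn.functional.pixel_unshuffle; it is an illustration, not the library's code, and it ignores any convolutions the real backbone applies between steps.

```python
import numpy as np


def pixel_unshuffle(x: np.ndarray, r: int) -> np.ndarray:
    """(B, C, H, W) -> (B, C*r*r, H//r, W//r), moving r x r spatial blocks into channels."""
    b, c, h, w = x.shape
    assert h % r == 0 and w % r == 0, "H and W must be divisible by r"
    x = x.reshape(b, c, h // r, r, w // r, r)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # (B, C, r, r, H//r, W//r)
    return x.reshape(b, c * r * r, h // r, w // r)


x = np.random.rand(2, 3, 64, 64)
y = x
for _ in range(3):  # three unshuffle steps with r=2: 64 -> 32 -> 16 -> 8
    y = pixel_unshuffle(y, 2)
print(y.shape)  # (2, 192, 8, 8)
```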
FC layer options
from ml_networks import MLPConfig, LinearConfig, SpatialSoftmaxConfig, AdaptiveAveragePoolingConfig

# Pick one of the following for fc_cfg:

# MLP head
fc_cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)
# Single linear layer
fc_cfg = LinearConfig(activation="ReLU", bias=True)
# Spatial softmax
fc_cfg = SpatialSoftmaxConfig(temperature=1.0)
# Adaptive average pooling
fc_cfg = AdaptiveAveragePoolingConfig()
# No head: output the feature map directly
fc_cfg = None
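A common spatial-softmax formulation from visuomotor learning takes a softmax over each channel's H*W positions and returns the expected (x, y) coordinate per channel, so C channels become 2*C features. The NumPy sketch below assumes that formulation and a [-1, 1] coordinate grid; whether SpatialSoftmaxConfig uses the same normalization and output ordering is an assumption.

```python
import numpy as np


def spatial_softmax(features: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """(B, C, H, W) feature maps -> (B, 2*C) expected (x, y) coordinates."""
    b, c, h, w = features.shape
    logits = features.reshape(b, c, h * w) / temperature
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)

    # Pixel-coordinate grid normalized to [-1, 1], flattened in row-major order
    ys, xs = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing="ij")
    expected_x = (probs * xs.reshape(1, 1, -1)).sum(axis=-1)
    expected_y = (probs * ys.reshape(1, 1, -1)).sum(axis=-1)
    return np.concatenate([expected_x, expected_y], axis=-1)  # all x, then all y


feat = np.random.randn(4, 64, 8, 8)
keypoints = spatial_softmax(feat, temperature=1.0)
print(keypoints.shape)  # (4, 128)
```

A lower temperature sharpens the softmax so each coordinate tracks the channel's strongest activation more closely.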

Decoder

import torch
from ml_networks.torch import Decoder
from ml_networks import ConvNetConfig, ConvConfig, MLPConfig, LinearConfig

backbone_cfg = ConvNetConfig(
    channels=[64, 32, 16],
    conv_cfgs=[
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="Tanh"),
    ],
)
fc_cfg = MLPConfig(
    hidden_dim=256,
    n_layers=2,
    output_activation="ReLU",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

decoder = Decoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
img = decoder(torch.randn(32, 64))  # (32, 3, 64, 64)
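Kernel 4, stride 2, padding 1 is the classic setting for exact 2x upsampling with a transposed convolution. Assuming the backbone uses transposed convolutions and the FC head reshapes the latent into an 8x8 map (both assumptions; the real starting size depends on how Decoder sizes its FC output), three such layers reach 64x64. The sketch applies PyTorch's ConvTranspose2d output-size formula; conv_transpose_out_size is a hypothetical helper.

```python
def conv_transpose_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0,
                            output_padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a ConvTranspose2d along one axis (PyTorch's formula)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


size = 8  # hypothetical map size after the FC head
sizes = [size]
for _ in range(3):  # three ConvConfig(kernel_size=4, stride=2, padding=1) layers
    size = conv_transpose_out_size(size, kernel_size=4, stride=2, padding=1)
    sizes.append(size)

print(sizes)  # [8, 16, 32, 64]
```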

Conditional UNet

import torch
from ml_networks.torch import ConditionalUnet2d, ConditionalUnet1d
from ml_networks import UNetConfig, ConvConfig

cfg = UNetConfig(
    channels=[64, 128, 256],
    conv_cfg=ConvConfig(kernel_size=3, padding=1, stride=1, activation="ReLU"),
    has_attn=True,
    nhead=8,
    cond_pred_scale=True,
)

# 2D (images)
net = ConditionalUnet2d(feature_dim=32, obs_shape=(3, 64, 64), cfg=cfg)
out = net(torch.randn(2, 3, 64, 64), cond=torch.randn(2, 32))  # (2, 3, 64, 64)

# 1D (sequences)
net = ConditionalUnet1d(feature_dim=32, obs_shape=(8, 128), cfg=cfg)
out = net(torch.randn(2, 8, 128), cond=torch.randn(2, 32))  # (2, 8, 128)
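The cond_pred_scale flag suggests that the condition vector predicts a per-channel scale as well as a bias, i.e. FiLM-style modulation of intermediate feature maps, similar to the cond_predict_scale option of Diffusion Policy's conditional 1D UNet. The NumPy sketch below shows that modulation under this assumption; the single linear projection and its weights w, b are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def film(x: np.ndarray, cond: np.ndarray, w: np.ndarray, b: np.ndarray, predict_scale: bool = True) -> np.ndarray:
    """Modulate feature maps x (B, C, ...) with a condition vector cond (B, D).

    w, b: hypothetical linear projection from D to 2*C (scale and bias),
    or to C (bias only) when predict_scale is False.
    """
    embed = cond @ w + b  # (B, 2*C) or (B, C)
    c = x.shape[1]
    extra_dims = (1,) * (x.ndim - 2)  # broadcast over spatial or time axes
    if predict_scale:
        scale, bias = embed[:, :c], embed[:, c:]
        return scale.reshape(-1, c, *extra_dims) * x + bias.reshape(-1, c, *extra_dims)
    return x + embed.reshape(-1, c, *extra_dims)


x = rng.standard_normal((2, 64, 16, 16))  # (B, C, H, W) intermediate feature map
cond = rng.standard_normal((2, 32))       # (B, feature_dim) condition
w = rng.standard_normal((32, 2 * 64)) * 0.1
b = np.zeros(2 * 64)
print(film(x, cond, w, b).shape)  # (2, 64, 16, 16)
```

The same function works for the 1D case, since only the first two axes are indexed explicitly.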

Distributions

import torch
from ml_networks.torch import Distribution, Encoder, stack_dist, cat_dist

# backbone_cfg and fc_cfg as in the Encoder example above
dist = Distribution(in_dim=64, dist="normal")
encoder = Encoder(feature_dim=128, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)

obs = torch.randn(32, 3, 64, 64)
z = encoder(obs)  # (B, 128): mean & std concatenated
dist_z = dist(z)  # NormalStoch(mean, std, stoch)

# Convert to torch.distributions for KLD computation
torch_dist = dist_z.get_distribution(independent=1)

# Stack / concatenate distribution objects
dist_list = [dist_z, dist(encoder(obs))]
stacked = stack_dist(dist_list, dim=0)
catted = cat_dist(dist_list, dim=-1)

# Save to disk (blosc2 format)
dist_z.save("reports/")
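Since the encoder output holds mean and std side by side, a diagonal-Normal head of this kind typically splits it in half, makes the scale positive, draws a reparameterized sample, and compares against a prior with the closed-form KL divergence. The NumPy sketch below assumes a softplus positivity transform and a standard Normal prior; the library's exact transform and prior are assumptions, and normal_head is a hypothetical name.

```python
import numpy as np


def softplus(x: np.ndarray) -> np.ndarray:
    return np.logaddexp(0.0, x)  # log(1 + exp(x)), numerically stable


def normal_head(z: np.ndarray, rng: np.random.Generator, min_std: float = 1e-4):
    """Split (B, 2*D) into mean/std and sample with the reparameterization trick."""
    mean, raw_std = np.split(z, 2, axis=-1)
    std = softplus(raw_std) + min_std  # hypothetical positivity transform
    stoch = mean + std * rng.standard_normal(mean.shape)
    return mean, std, stoch


def kl_to_standard_normal(mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """KL(N(mean, std^2) || N(0, 1)), summed over the last axis (like independent=1)."""
    return 0.5 * np.sum(std**2 + mean**2 - 1.0 - 2.0 * np.log(std), axis=-1)


rng = np.random.default_rng(0)
z = rng.standard_normal((32, 128))
mean, std, stoch = normal_head(z, rng)
print(stoch.shape, kl_to_standard_normal(mean, std).shape)  # (32, 64) (32,)
```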

Loss Functions

from ml_networks.torch import focal_loss, binary_focal_loss, FocalFrequencyLoss, charbonnier

# logits, labels: classifier outputs and class targets
# pred, target: reconstructed and ground-truth images, e.g. (B, C, H, W)

# Focal loss (classification)
loss = focal_loss(logits, labels, gamma=2.0)

# Charbonnier loss (image reconstruction)
loss = charbonnier(pred, target, epsilon=1e-3)

# Focal frequency loss (frequency-domain reconstruction)
ffl = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0)
loss = ffl(pred, target)
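For reference, the standard definitions behind two of these losses: Charbonnier is a smooth L1 variant, mean(sqrt((pred - target)^2 + epsilon^2)), and focal loss scales cross-entropy by (1 - p_t)^gamma so well-classified examples contribute less. The NumPy sketch below implements those textbook formulas; the reductions and argument conventions in ml_networks may differ, and the _ref names are hypothetical.

```python
import numpy as np


def charbonnier_ref(pred: np.ndarray, target: np.ndarray, epsilon: float = 1e-3) -> float:
    """Mean Charbonnier penalty: a differentiable approximation of L1."""
    return float(np.mean(np.sqrt((pred - target) ** 2 + epsilon**2)))


def focal_loss_ref(logits: np.ndarray, labels: np.ndarray, gamma: float = 2.0) -> float:
    """Multi-class focal loss: mean of -(1 - p_t)^gamma * log(p_t)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    log_pt = log_probs[np.arange(len(labels)), labels]  # log-prob of the true class
    pt = np.exp(log_pt)
    return float(np.mean(-((1.0 - pt) ** gamma) * log_pt))


logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
print(focal_loss_ref(logits, labels, gamma=2.0))
print(charbonnier_ref(np.zeros(4), np.zeros(4)))  # ~epsilon when pred == target
```

With gamma=0 the focal loss reduces to plain cross-entropy.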

JAX Backend

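Switch frameworks without changing your configs: the same config classes from ml_networks drive the modules in ml_networks.jax. JAX modules additionally require an nnx.Rngs object (the rngs argument) at initialization.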

from ml_networks.jax import MLPLayer, Encoder, Decoder
from ml_networks import MLPConfig, LinearConfig, ConvNetConfig, ConvConfig
from flax import nnx
import jax.numpy as jnp

cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

rngs = nnx.Rngs(0)
mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg, rngs=rngs)
y = mlp(jnp.ones((32, 16)))  # (32, 8)

JAX modules use NHWC format (channels-last), whereas the PyTorch modules use NCHW. See the JAX Guide for details.
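Arrays moved between the two backends therefore need their axes permuted. A minimal NumPy sketch; jnp.transpose takes the same axes argument, and on the PyTorch side x.permute(0, 2, 3, 1) does the same.

```python
import numpy as np

nchw = np.random.rand(32, 3, 64, 64).astype(np.float32)  # PyTorch-style batch
nhwc = np.transpose(nchw, (0, 2, 3, 1))                   # channels-last for the JAX modules
back = np.transpose(nhwc, (0, 3, 1, 2))                   # and back to NCHW

print(nhwc.shape)  # (32, 64, 64, 3)
```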

Data I/O (blosc2)

from ml_networks import save_blosc2, load_blosc2

save_blosc2(data, "dataset/image.blosc2")
loaded = load_blosc2("dataset/image.blosc2")

Utilities

from ml_networks.torch import Activation, get_optimizer, torch_fix_seed
from ml_networks import determine_loader

# Custom activations: REReLU, SiGLU, CRReLU, TanhExp, L2Norm
act = Activation("REReLU")

# Optimizer (supports pytorch-optimizer library)
optimizer = get_optimizer(model.parameters(), "Adam", lr=1e-3)

# Reproducibility
torch_fix_seed(42)
loader = determine_loader(dataset, seed=42, batch_size=32, shuffle=True)
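Of the custom activations listed, TanhExp has a compact published definition, f(x) = x * tanh(exp(x)): close to the identity for large positive inputs and decaying smoothly to 0 for negative ones. The NumPy sketch below implements that formula for intuition only; it is not the library's Activation class, and the other activations are not reproduced here.

```python
import numpy as np


def tanh_exp(x: np.ndarray) -> np.ndarray:
    """TanhExp activation: x * tanh(exp(x))."""
    # Clip before exp to avoid overflow; tanh(exp(20)) is already 1.0 in float64
    return x * np.tanh(np.exp(np.minimum(x, 20.0)))


x = np.linspace(-5, 5, 11)
print(np.round(tanh_exp(x), 4))
```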

Package Structure

ml_networks/
├── config.py          # Shared config classes (PyTorch/JAX)
├── utils.py           # Shared utilities (blosc2 I/O, conv shape calc)
├── callbacks.py       # PyTorch Lightning callbacks
├── torch/             # PyTorch implementation
│   ├── layers.py      # MLP, Conv, Attention, Transformer
│   ├── vision.py      # Encoder, Decoder, ConvNet, ResNet, ViT
│   ├── unet.py        # ConditionalUnet1d/2d
│   ├── distributions.py
│   ├── loss.py
│   ├── activations.py
│   ├── hypernetworks.py
│   ├── contrastive.py
│   └── torch_utils.py
└── jax/               # JAX (Flax NNX) implementation
    ├── layers.py
    ├── vision.py
    ├── unet.py
    ├── distributions.py
    ├── loss.py
    ├── activations.py
    ├── hypernetworks.py
    ├── contrastive.py
    └── jax_utils.py

Development

git clone https://github.com/keio-crl/ml-networks.git
cd ml-networks
pip install -e ".[dev]"

# Quality checks
ruff check .          # Lint
ruff format .         # Format
mypy src/             # Type check
pre-commit run --all-files  # All checks

Versioning

The project follows Semantic Versioning (MAJOR.MINOR.PATCH). Bump the version with:

python scripts/bump_version.py patch   # 0.1.0 -> 0.1.1
python scripts/bump_version.py minor   # 0.1.0 -> 0.2.0
python scripts/bump_version.py major   # 0.1.0 -> 1.0.0
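The three bump levels follow the usual SemVer rules: patch increments PATCH; minor increments MINOR and resets PATCH; major increments MAJOR and resets the other two. A plain-Python sketch of that logic; bump is a hypothetical function, not the contents of scripts/bump_version.py.

```python
def bump(version: str, part: str) -> str:
    """Return the next MAJOR.MINOR.PATCH version for part in {"major", "minor", "patch"}."""
    major, minor, patch = (int(n) for n in version.split("."))
    if part == "major":
        return f"{major + 1}.0.0"
    if part == "minor":
        return f"{major}.{minor + 1}.0"
    if part == "patch":
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown part: {part!r}")


for part in ("patch", "minor", "major"):
    print(part, bump("0.1.0", part))  # 0.1.1, 0.2.0, 1.0.0
```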

Or use the Version Bump workflow on GitHub Actions.

CI/CD

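| Workflow | Trigger                    | Checks                    |
|----------|----------------------------|---------------------------|
| CI       | Push / PR to main, develop | ruff, mypy, pytest, build |
| Release  | Tag push (vX.Y.Z)          | Build + GitHub Release    |
| Docs     | Push to main               | Deploy documentation      |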



Download files

Download the file for your platform.

Source Distribution

ml_networks-0.2.5.tar.gz (321.3 kB)

Uploaded Source

Built Distribution


ml_networks-0.2.5-py3-none-any.whl (94.1 kB)

Uploaded Python 3

File details

Details for the file ml_networks-0.2.5.tar.gz.

File metadata

  • Download URL: ml_networks-0.2.5.tar.gz
  • Upload date:
  • Size: 321.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for ml_networks-0.2.5.tar.gz

| Algorithm   | Hash digest                                                      |
|-------------|------------------------------------------------------------------|
| SHA256      | 2d20d0d16f4757dc8d6bff313a67a5f56ae27386c8bf77640b48fb1e9af807fd |
| MD5         | e306582ddf36240edf28fdc9c1f22547                                 |
| BLAKE2b-256 | 3e855fc7e5531f82fa26eb79c64c846e625115fcc4b9983b5e728b67930e13ce |
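To check a downloaded archive against the SHA256 digest listed for it, hash the file locally and compare. A minimal sketch using Python's standard hashlib; the file name assumes the sdist was saved to the current directory.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "2d20d0d16f4757dc8d6bff313a67a5f56ae27386c8bf77640b48fb1e9af807fd"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    archive = Path("ml_networks-0.2.5.tar.gz")
    if archive.exists():
        print("OK" if sha256_of(archive) == EXPECTED_SHA256 else "MISMATCH")
```

pip can enforce the same check at install time via --hash=sha256:... entries in a requirements file together with --require-hashes.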


Provenance

The following attestation bundles were made for ml_networks-0.2.5.tar.gz:

Publisher: release.yml on keio-crl/ml-networks

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file ml_networks-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: ml_networks-0.2.5-py3-none-any.whl
  • Upload date:
  • Size: 94.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for ml_networks-0.2.5-py3-none-any.whl

| Algorithm   | Hash digest                                                      |
|-------------|------------------------------------------------------------------|
| SHA256      | fefa42b0807cd8ec87fa05c4f4e3bb1899533b54f437b2af99ef76cce2d061f5 |
| MD5         | e2cd994cdef45de5cb56a903337a7e30                                 |
| BLAKE2b-256 | 9c5dc552eb96521c6ed43a8858a65d403ca1a998ee652611ba23a98fb899ab87 |


Provenance

The following attestation bundles were made for ml_networks-0.2.5-py3-none-any.whl:

Publisher: release.yml on keio-crl/ml-networks

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
