A unified deep learning model building library for PyTorch and JAX with shared dataclass-based configuration

ml-networks

A deep learning model building library supporting both PyTorch and JAX (Flax NNX). Build models across frameworks using a unified Config system.

Documentation | JAX Guide


Features

  • Dual Framework: PyTorch / JAX (Flax NNX) with a shared Config system
  • Vision Models: Encoder, Decoder, ConvNet, ResNet (PixelShuffle/Unshuffle), Vision Transformer
  • Generative Models: Conditional UNet (1D/2D) for Diffusion Models
  • Distributions: Normal, Categorical, Bernoulli, BSQ Codebook
  • Loss Functions: Focal Loss, Charbonnier Loss, Focal Frequency Loss, KL Divergence
  • Advanced: HyperNetwork, Contrastive Learning, SpatialSoftmax
  • Utilities: Custom activations, optimizers, blosc2 I/O, seed control

Installation

# PyTorch modules only
pip install ml-networks

# With JAX support (Python 3.11+ required)
pip install "ml-networks[jax]"

uv / rye

# uv
uv add ml-networks
uv add "ml-networks[jax]"  # with JAX

# rye
rye add ml-networks
rye add "ml-networks[jax]"  # with JAX

Development version (from GitHub)

pip install git+https://github.com/keio-crl/ml-networks.git

Requirements

             PyTorch backend    JAX backend
Python       >= 3.10            >= 3.11
Framework    PyTorch >= 2.0     JAX >= 0.4.30, Flax >= 0.10.0

Quick Start

MLP

import torch
from ml_networks.torch import MLPLayer
from ml_networks import MLPConfig, LinearConfig

cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg)
y = mlp(torch.randn(32, 16))  # (32, 8)

Encoder

import torch
from ml_networks.torch import Encoder
from ml_networks import ConvNetConfig, ConvConfig, LinearConfig

backbone_cfg = ConvNetConfig(
    channels=[16, 32, 64],
    conv_cfgs=[
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU"),
    ],
)
fc_cfg = LinearConfig(activation="ReLU", bias=True)

encoder = Encoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
z = encoder(torch.randn(32, 3, 64, 64))  # (32, 64)

Backbone options

from ml_networks import ResNetConfig, ViTConfig, TransformerConfig

# ResNet + PixelUnshuffle
backbone_cfg = ResNetConfig(
    conv_channel=64,
    conv_kernel=3,
    f_kernel=3,
    conv_activation="ReLU",
    out_activation="ReLU",
    n_res_blocks=3,
    scale_factor=2,
    n_scaling=3,
    norm="batch",
    norm_cfg={"affine": True},
)

# Vision Transformer
backbone_cfg = ViTConfig(
    patch_size=8,
    transformer_cfg=TransformerConfig(
        d_model=64,
        nhead=8,
        dim_ff=256,
        n_layers=3,
    ),
)

FC layer options

from ml_networks import MLPConfig, LinearConfig, SpatialSoftmaxConfig, AdaptiveAveragePoolingConfig

# The assignments below are alternatives; pick one.

# MLP head
fc_cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)
# Single linear layer
fc_cfg = LinearConfig(activation="ReLU", bias=True)
# Spatial softmax over the feature map
fc_cfg = SpatialSoftmaxConfig(temperature=1.0)
# Adaptive average pooling
fc_cfg = AdaptiveAveragePoolingConfig()
fc_cfg = None  # output feature map directly
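
The shapes implied by the backbone configs above can be checked by hand. The sketch below is not part of ml-networks: `conv_out_size` and `vit_num_patches` are hypothetical helper names, and the formulas are the standard ones for an undilated convolution and for non-overlapping ViT patches. It assumes the 64x64 input used in the Encoder example.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of an undilated convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def vit_num_patches(height: int, width: int, patch_size: int) -> int:
    """Number of non-overlapping square patches cut from a height x width image."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    return (height // patch_size) * (width // patch_size)


# Three stride-2 convolutions (kernel_size=3, padding=1) on a 64x64 input,
# matching the ConvConfig values in the Encoder example.
size = 64
sizes = []
for _ in range(3):
    size = conv_out_size(size, kernel_size=3, stride=2, padding=1)
    sizes.append(size)
print(sizes)  # [32, 16, 8]

# ViT backbone with patch_size=8 on the same 64x64 input.
print(vit_num_patches(64, 64, patch_size=8))  # 64
```

Under these assumptions the ConvNet backbone ends on an 8x8 feature map with 64 channels, and the ViT backbone sees a sequence of 64 patches.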

Decoder

import torch
from ml_networks.torch import Decoder
from ml_networks import ConvNetConfig, ConvConfig, MLPConfig, LinearConfig

backbone_cfg = ConvNetConfig(
    channels=[64, 32, 16],
    conv_cfgs=[
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU"),
        ConvConfig(kernel_size=4, stride=2, padding=1, activation="Tanh"),
    ],
)
fc_cfg = MLPConfig(
    hidden_dim=256,
    n_layers=2,
    output_activation="ReLU",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

decoder = Decoder(feature_dim=64, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)
img = decoder(torch.randn(32, 64))  # (32, 3, 64, 64)
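
The kernel_size=4, stride=2, padding=1 setting used above is the usual choice for doubling resolution with a transposed convolution. The source does not say which upsampling layer Decoder uses or what resolution it starts from, so the sketch below is only an illustration under two assumptions: transposed convolutions, and a starting feature map of 8x8. `conv_transpose_out_size` is a hypothetical helper, not a library function.

```python
def conv_transpose_out_size(
    size: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
) -> int:
    """Spatial output size of an undilated transposed convolution."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


# Assumed starting resolution (8x8); each layer uses kernel_size=4, stride=2, padding=1.
size = 8
up_sizes = []
for _ in range(3):
    size = conv_transpose_out_size(size, kernel_size=4, stride=2, padding=1)
    up_sizes.append(size)
print(up_sizes)  # [16, 32, 64]
```

With these numbers, three layers take 8x8 back to the 64x64 resolution given by obs_shape.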

Conditional UNet

import torch
from ml_networks.torch import ConditionalUnet2d, ConditionalUnet1d
from ml_networks import UNetConfig, ConvConfig

cfg = UNetConfig(
    channels=[64, 128, 256],
    conv_cfg=ConvConfig(kernel_size=3, padding=1, stride=1, activation="ReLU"),
    has_attn=True,
    nhead=8,
    cond_pred_scale=True,
)

# 2D (images)
net = ConditionalUnet2d(feature_dim=32, obs_shape=(3, 64, 64), cfg=cfg)
out = net(torch.randn(2, 3, 64, 64), cond=torch.randn(2, 32))  # (2, 3, 64, 64)

# 1D (sequences)
net = ConditionalUnet1d(feature_dim=32, obs_shape=(8, 128), cfg=cfg)
out = net(torch.randn(2, 8, 128), cond=torch.randn(2, 32))  # (2, 8, 128)

Distributions

from ml_networks.torch import Distribution, Encoder, stack_dist, cat_dist

# `backbone_cfg` and `fc_cfg` are configs such as those built in the Encoder section;
# `obs` is a batch of observations shaped (B, 3, 64, 64).
dist = Distribution(in_dim=64, dist="normal")
encoder = Encoder(feature_dim=128, obs_shape=(3, 64, 64), backbone_cfg=backbone_cfg, fc_cfg=fc_cfg)

z = encoder(obs)  # (B, 128): mean & std concatenated
dist_z = dist(z)  # NormalStoch(mean, std, stoch)

# Convert to torch.distributions for KLD computation
torch_dist = dist_z.get_distribution(independent=1)

# Stack / concatenate distribution objects
# (`dist_list` is a list of distribution objects such as `dist_z`)
stacked = stack_dist(dist_list, dim=0)
catted = cat_dist(dist_list, dim=-1)

# Save to disk (blosc2 format)
dist_z.save("reports/")
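
For reference, the KL divergence between two diagonal Gaussians has a closed form. The sketch below is a plain-Python illustration of that formula, not the library's code path (the snippet above hands the computation to torch.distributions). `kl_normal` is a hypothetical name, and summing over the feature dimension is an assumption about what independent=1 implies.

```python
import math


def kl_normal(mean_p, std_p, mean_q, std_q) -> float:
    """KL(p || q) for diagonal Gaussians, summed over the feature dimension."""
    total = 0.0
    for mp, sp, mq, sq in zip(mean_p, std_p, mean_q, std_q):
        # Per-dimension term: log(sq / sp) + (sp^2 + (mp - mq)^2) / (2 sq^2) - 1/2
        total += math.log(sq / sp) + (sp**2 + (mp - mq) ** 2) / (2 * sq**2) - 0.5
    return total


# Identical distributions have zero divergence.
print(kl_normal([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))  # 0.0

# Unit-variance Gaussians whose means differ by 1 give 0.5 per dimension.
print(kl_normal([0.0], [1.0], [1.0], [1.0]))  # 0.5
```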

Loss Functions

from ml_networks.torch import focal_loss, binary_focal_loss, FocalFrequencyLoss, charbonnier

# Focal loss (classification)
loss = focal_loss(logits, labels, gamma=2.0)

# Charbonnier loss (image reconstruction)
loss = charbonnier(pred, target, epsilon=1e-3)

# Focal frequency loss (frequency-domain reconstruction)
ffl = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0)
loss = ffl(pred, target)
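
As a reading aid, here are scalar versions of two of these losses in plain Python. They follow common textbook definitions and are assumptions about the general form only: the library functions operate on tensors, focal_loss takes logits rather than probabilities, and its exact reduction and epsilon handling are not shown in this README. The `_scalar` names are hypothetical.

```python
import math


def charbonnier_scalar(pred, target, epsilon: float = 1e-3) -> float:
    """Mean of sqrt((pred - target)^2 + epsilon^2) over paired values."""
    terms = [math.sqrt((p - t) ** 2 + epsilon**2) for p, t in zip(pred, target)]
    return sum(terms) / len(terms)


def focal_loss_scalar(probs, label: int, gamma: float = 2.0) -> float:
    """Focal loss for one sample: -(1 - p_t)^gamma * log(p_t), with p_t = probs[label]."""
    p_t = probs[label]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


# A perfect reconstruction still costs epsilon per element.
perfect = charbonnier_scalar([1.0, 2.0], [1.0, 2.0], epsilon=1e-3)

# With gamma=0 the focal loss reduces to cross-entropy; gamma=2 down-weights
# a confidently correct prediction.
ce = focal_loss_scalar([0.1, 0.9], label=1, gamma=0.0)
fl = focal_loss_scalar([0.1, 0.9], label=1, gamma=2.0)
print(perfect, ce, fl)
```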

JAX Backend

Switch frameworks without changing your configs. JAX modules additionally require an rngs argument (an nnx.Rngs instance) at initialization.

from ml_networks.jax import MLPLayer, Encoder, Decoder
from ml_networks import MLPConfig, LinearConfig, ConvNetConfig, ConvConfig
from flax import nnx
import jax.numpy as jnp

cfg = MLPConfig(
    hidden_dim=128,
    n_layers=2,
    output_activation="Tanh",
    linear_cfg=LinearConfig(activation="ReLU", bias=True),
)

rngs = nnx.Rngs(0)
mlp = MLPLayer(input_dim=16, output_dim=8, cfg=cfg, rngs=rngs)
y = mlp(jnp.ones((32, 16)))  # (32, 8)

JAX modules use NHWC format (channels-last). See the JAX Guide for details.
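
The PyTorch examples above pass images as (batch, channels, height, width), so moving data to the JAX modules means permuting axes. The sketch below shows the permutation on shape tuples only. `permute_shape` and the two constants are hypothetical names, not library API.

```python
def permute_shape(shape, perm):
    """Return the shape an array would have after transposing its axes by `perm`."""
    return tuple(shape[i] for i in perm)


NCHW_TO_NHWC = (0, 2, 3, 1)
NHWC_TO_NCHW = (0, 3, 1, 2)

torch_shape = (32, 3, 64, 64)  # channels-first, as in the PyTorch examples
jax_shape = permute_shape(torch_shape, NCHW_TO_NHWC)
print(jax_shape)  # (32, 64, 64, 3)

# The inverse permutation restores the channels-first layout.
print(permute_shape(jax_shape, NHWC_TO_NCHW))  # (32, 3, 64, 64)
```

On actual arrays the same permutation can be applied with jnp.transpose(x, (0, 2, 3, 1)).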

Data I/O (blosc2)

from ml_networks import save_blosc2, load_blosc2

save_blosc2(data, "dataset/image.blosc2")
loaded = load_blosc2("dataset/image.blosc2")

Utilities

from ml_networks.torch import Activation, get_optimizer, torch_fix_seed
from ml_networks import determine_loader

# Custom activations: REReLU, SiGLU, CRReLU, TanhExp, L2Norm
act = Activation("REReLU")

# Optimizer (supports pytorch-optimizer library)
optimizer = get_optimizer(model.parameters(), "Adam", lr=1e-3)

# Reproducibility
torch_fix_seed(42)
loader = determine_loader(dataset, seed=42, batch_size=32, shuffle=True)
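
Two of the listed activations have simple closed forms. The scalar sketch below uses the published TanhExp definition, x * tanh(exp(x)), and assumes that L2Norm means scaling a vector to unit Euclidean length. Neither function is taken from the library, whose implementations may differ in detail (for example in the axis or epsilon used).

```python
import math


def tanh_exp(x: float) -> float:
    """TanhExp activation: x * tanh(exp(x))."""
    # For large x, tanh(exp(x)) is 1 to floating-point precision;
    # returning x directly avoids overflow in math.exp.
    if x > 20.0:
        return x
    return x * math.tanh(math.exp(x))


def l2_normalize(values, eps: float = 1e-12):
    """Scale a vector to unit Euclidean norm; eps guards against division by zero."""
    norm = math.sqrt(sum(v * v for v in values))
    return [v / max(norm, eps) for v in values]


print(tanh_exp(0.0))             # 0.0
print(l2_normalize([3.0, 4.0]))  # approximately [0.6, 0.8]
```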

Package Structure

ml_networks/
├── config.py          # Shared config classes (PyTorch/JAX)
├── utils.py           # Shared utilities (blosc2 I/O, conv shape calc)
├── callbacks.py       # PyTorch Lightning callbacks
├── torch/             # PyTorch implementation
│   ├── layers.py      # MLP, Conv, Attention, Transformer
│   ├── vision.py      # Encoder, Decoder, ConvNet, ResNet, ViT
│   ├── unet.py        # ConditionalUnet1d/2d
│   ├── distributions.py
│   ├── loss.py
│   ├── activations.py
│   ├── hypernetworks.py
│   ├── contrastive.py
│   └── torch_utils.py
└── jax/               # JAX (Flax NNX) implementation
    ├── layers.py
    ├── vision.py
    ├── unet.py
    ├── distributions.py
    ├── loss.py
    ├── activations.py
    ├── hypernetworks.py
    ├── contrastive.py
    └── jax_utils.py

Development

git clone https://github.com/keio-crl/ml-networks.git
cd ml-networks
pip install -e ".[dev]"

# Quality checks
ruff check .          # Lint
ruff format .         # Format
mypy src/             # Type check
pre-commit run --all-files  # All checks

Versioning

This project follows Semantic Versioning (MAJOR.MINOR.PATCH).

python scripts/bump_version.py patch   # 0.1.0 -> 0.1.1
python scripts/bump_version.py minor   # 0.1.0 -> 0.2.0
python scripts/bump_version.py major   # 0.1.0 -> 1.0.0

Or use the Version Bump workflow on GitHub Actions.
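
The bump rules shown in the comments above can be written out as a small function. This is a sketch of the Semantic Versioning arithmetic only. It is not the contents of scripts/bump_version.py, and `bump_version` is a hypothetical name.

```python
def bump_version(version: str, part: str) -> str:
    """Increment one component of a MAJOR.MINOR.PATCH string, resetting lower ones."""
    major, minor, patch = (int(piece) for piece in version.split("."))
    if part == "major":
        return f"{major + 1}.0.0"
    if part == "minor":
        return f"{major}.{minor + 1}.0"
    if part == "patch":
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown part: {part!r}")


print(bump_version("0.1.0", "patch"))  # 0.1.1
print(bump_version("0.1.0", "minor"))  # 0.2.0
print(bump_version("0.1.0", "major"))  # 1.0.0
```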

CI/CD

Workflow    Trigger                       Checks
CI          Push / PR to main, develop    ruff, mypy, pytest, build
Release     Tag push (vX.Y.Z)             Build + GitHub Release
Docs        Push to main                  Deploy documentation
