
A PyTorch adapter for forward-only model definition

blaze logo

Blaze: Name less. Build more.

A PyTorch adapter inspired by Haiku's functional programming model. Write stateless forward functions using inline layer calls — no nn.Module boilerplate — and get back a proper nn.Module with full parameter management and torch.jit.script support.

✨ Why blaze?

The traditional way to define PyTorch models makes you write every layer twice — declared in __init__, used again in forward — and forces you to invent a name for each one (naming things: the programmer's arch nemesis), which slows development. Blaze removes all of that: layers are written once, inline, exactly where they're used.

Example: Convolutional network

Traditional PyTorch:

class ConvNet(nn.Module):
    def __init__(self):          # ← boilerplate you must write
        super().__init__()       # ← boilerplate you must write
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)  # ← named here...
        self.bn1   = nn.BatchNorm2d(32)               # ← named here...
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)  # ← named here...
        self.bn2   = nn.BatchNorm2d(64)               # ← named here...
        self.pool  = nn.AdaptiveAvgPool2d(1)          # ← named here...
        self.fc    = nn.Linear(64, 10)                # ← named here...

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))  # ← ...and used here
        x = F.relu(self.bn2(self.conv2(x)))  # ← ...and used here
        x = self.pool(x).flatten(1)          # ← ...and used here
        return self.fc(x)                    # ← ...and used here, what's the output dim again?

model = ConvNet()

Blaze:

# No class. No __init__. No self. No invented names. Only logic.
def forward(x):
    x = F.relu(bl.BatchNorm2d(32)(bl.Conv2d(3, 32, 3, padding=1)(x)))
    x = F.relu(bl.BatchNorm2d(64)(bl.Conv2d(32, 64, 3, padding=1)(x)))
    x = bl.AdaptiveAvgPool2d(1)(x).flatten(1)
    return bl.Linear(64, 10)(x)

model = bl.transform(forward)
model.init(torch.randn(1, 3, 32, 32)) # discovers and creates all modules

🗑️ What gets eliminated

  • class MyModel(nn.Module) → a plain function or a thin bl.Module subclass
  • def __init__(self) → not needed
  • super().__init__() → not needed
  • self.layer = nn.Linear(...) → not needed; layers are created inline
  • Inventing a name for every layer → auto-derived from the class name and deduplicated
  • nn.ModuleList / nn.ModuleDict for dynamic structure → a plain Python loop or dict
  • Passing hyperparameters through __init__ just to store them for forward → use them directly in the function
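
For contrast, here is the dynamic-structure case in plain PyTorch: loop-created layers must live in an nn.ModuleList, or the module will not see them. (DynamicMLP is an illustrative example, not part of Blaze.)

```python
import torch
import torch.nn as nn

class DynamicMLP(nn.Module):
    """Plain PyTorch: layers built in a loop must go in an nn.ModuleList,
    otherwise .parameters() and state_dict() will not register them."""
    def __init__(self, width: int, depth: int):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(width, width) for _ in range(depth)
        )

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x

model = DynamicMLP(width=16, depth=3)
# 3 Linear layers -> 3 weights + 3 biases registered as parameters
```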

🚀 Features

  • 🧹 No nn.Module boilerplate — define models as plain functions; layers are called inline.
  • 🔌 Drop-in compatible — BlazeModule is a standard nn.Module; training loops, optimizers, state_dict, and deployment code need no changes.
  • ⚙️ Automatic parameter management — weights are discovered on the first init() pass and reused on every subsequent forward call.
  • 🗂️ Hierarchical naming — module paths (e.g. "block/linear", "block/linear_1") are derived automatically from class names and deduplicated per scope.
  • 🧩 Composable modules — subclass bl.Module to build reusable components; scopes nest correctly no matter how deep.
  • 🎛️ Raw parameters — get_parameter() creates a learnable nn.Parameter scoped to the current path, without any surrounding module.
  • 💾 Non-trainable state — get_state() creates a buffer tensor (analogous to hk.get_state) that is tracked by the module but excluded from gradient updates.
  • 🏷️ Custom names — any layer call accepts a name= keyword to override the auto-derived registry key.
  • ⚡ torch.jit.script support — after init(), models can be scripted for deployment with no extra steps.
  • 🔄 Train/eval propagation — .train() / .eval() propagate correctly to all registered sub-modules.
  • 🧱 Built-in layer wrappers — covers linear, conv, norm, pooling, activation, dropout, recurrent, embedding, attention, transformer, and shape layers.

📦 Installation

pip install blaze-torch

🧑‍💻 Quickstart

import torch
import blaze as bl

def forward(x):
    x = bl.Linear(10, 64)(x)
    x = bl.ReLU()(x)
    x = bl.Linear(64, 1)(x)
    return x

model = bl.transform(forward)
model.init(torch.randn(4, 10))   # discovers and creates all modules

out = model(torch.randn(4, 10))  # normal nn.Module usage

📖 Core concepts

🔁 Two-phase execution

bl.transform wraps your function. Calling .init(sample_input) runs an INIT pass that discovers every layer call and registers it into an internal registry keyed by its hierarchical path (e.g. "block/linear"). Subsequent calls run in APPLY mode, reusing the registered modules by call order.

model = bl.transform(forward)
model.init(torch.randn(batch, in_features))  # INIT pass — creates weights
output = model(x)                            # APPLY pass — reuses weights
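
The mechanism can be pictured with a framework-free sketch (names here are illustrative, not Blaze's internals): a registry that runs a factory on the first lookup of a path and returns the cached module on every later lookup.

```python
class Registry:
    """Toy model of the INIT/APPLY registry: the first lookup of a path
    runs the factory (INIT); later lookups reuse the cached entry (APPLY)."""
    def __init__(self):
        self.modules = {}

    def get(self, path, factory):
        if path not in self.modules:
            self.modules[path] = factory()   # INIT: create and cache
        return self.modules[path]            # APPLY: reuse

reg = Registry()
first = reg.get("block/linear", lambda: {"w": 1.0})
second = reg.get("block/linear", lambda: {"w": 2.0})
# `first is second` — the second factory never ran
```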

🧩 User-defined modules (bl.Module)

Subclass bl.Module and implement __call__ to group layers into reusable components. The class name is automatically converted to snake_case for scoping, and repeated instantiations are deduplicated with a numeric suffix.

class MLP(bl.Module):
    def __call__(self, x):
        x = bl.Linear(x.shape[-1], 128)(x)
        x = bl.GELU()(x)
        x = bl.Linear(128, x.shape[-1])(x)
        return x

def forward(x):
    x = MLP()(x)   # parameter names: "mlp/linear", "mlp/gelu", "mlp/linear_1"
    x = MLP()(x)   # parameter names: "mlp_1/linear", ...
    return x
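
The naming rule can be sketched in plain Python (an illustrative approximation, not Blaze's actual implementation): CamelCase is split at lowercase-to-uppercase boundaries, and a per-scope counter appends _1, _2, … for repeats.

```python
import re
from collections import defaultdict

def to_snake(class_name: str) -> str:
    # Insert "_" at each lowercase/digit -> uppercase boundary, then lowercase.
    return re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", class_name).lower()

def unique_name(base: str, counts: dict) -> str:
    # Per-scope deduplication: first "mlp", then "mlp_1", "mlp_2", ...
    n = counts[base]
    counts[base] += 1
    return base if n == 0 else f"{base}_{n}"

counts = defaultdict(int)
names = [unique_name(to_snake("MLP"), counts) for _ in range(2)]
# names == ["mlp", "mlp_1"]
```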

🎛️ Raw parameters (bl.get_parameter)

Create a learnable nn.Parameter directly, scoped to the current name context. Analogous to hk.get_parameter.

def forward(x):
    scale = bl.get_parameter("scale", (x.shape[-1],), init_fn=torch.ones)
    bias  = bl.get_parameter("bias",  (x.shape[-1],), init_fn=torch.zeros)
    return x * scale + bias
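
In plain PyTorch terms, a get_parameter-style helper must create the nn.Parameter once, register it on the owning module, and return the cached one on later calls. A minimal sketch (ParamOwner and its param method are illustrative, not Blaze's internals):

```python
import torch
import torch.nn as nn

class ParamOwner(nn.Module):
    """Sketch of create-once / reuse-later parameter lookup."""
    def param(self, name, shape, init_fn):
        if name not in dict(self.named_parameters()):
            self.register_parameter(name, nn.Parameter(init_fn(shape)))
        return getattr(self, name)

owner = ParamOwner()
scale = owner.param("scale", (4,), torch.ones)
again = owner.param("scale", (4,), torch.zeros)
# `scale is again` — the second init_fn never ran
```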

💾 Non-trainable state (bl.get_state)

Create a buffer tensor (non-trainable, tracked by the module). Analogous to hk.get_state.

def forward(x):
    running_mean = bl.get_state("running_mean", (x.shape[-1],), init_fn=torch.zeros)
    return x - running_mean
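
The buffer analogue in plain PyTorch uses register_buffer: the tensor appears in state_dict() but not in parameters(), so optimizers never touch it. A minimal sketch (StateOwner and its state method are illustrative):

```python
import torch
import torch.nn as nn

class StateOwner(nn.Module):
    """Sketch of create-once, non-trainable state via register_buffer."""
    def state(self, name, shape, init_fn):
        if name not in dict(self.named_buffers()):
            self.register_buffer(name, init_fn(shape))
        return getattr(self, name)

owner = StateOwner()
mean = owner.state("running_mean", (8,), torch.zeros)
# saved with the module, but invisible to the optimizer
```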

🏷️ Custom names

Pass name= to any layer call to override the auto-derived name:

def forward(x):
    x = bl.Linear(10, 64, name="encoder")(x)
    x = bl.Linear(64, 10, name="decoder")(x)
    return x

⚡ TorchScript / JIT

After .init(), a model can be scripted with torch.jit.script:

model = bl.transform(forward)
model.init(torch.randn(2, 10))

scripted = torch.jit.script(model)
out = scripted(torch.randn(2, 10))

🧱 Available layers

All wrappers accept the same arguments as their torch.nn counterparts.

  • Linear: Linear, Bilinear
  • Conv: Conv1d/2d/3d, ConvTranspose1d/2d/3d
  • Norm: BatchNorm1d/2d/3d, SyncBatchNorm, InstanceNorm1d/2d/3d, LayerNorm, GroupNorm, RMSNorm
  • Pooling: MaxPool1d/2d/3d, AvgPool1d/2d/3d, AdaptiveAvgPool1d/2d/3d, AdaptiveMaxPool1d/2d/3d
  • Activation: ReLU, ReLU6, LeakyReLU, PReLU, ELU, SELU, CELU, GELU, Mish, SiLU, Tanh, Sigmoid, Hardsigmoid, Hardswish, Softmax, LogSoftmax, Softplus
  • Dropout: Dropout, Dropout1d/2d/3d, AlphaDropout
  • Recurrent: LSTM, GRU, RNN, LSTMCell, GRUCell, RNNCell
  • Embedding: Embedding, EmbeddingBag
  • Attention: MultiheadAttention
  • Transformer: Transformer, TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
  • Shape: Flatten, Unflatten, Upsample, PixelShuffle, PixelUnshuffle
  • Misc: Identity

🏋️ Training

BlazeModule is a standard nn.Module — use any PyTorch optimizer:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

🔗 Related projects

  • dm-haiku (JAX) — The original inspiration. Transforms stateful hk.Module code into pure (init, apply) function pairs via hk.transform.
  • Flax NNX (JAX) — Google's neural network library for JAX. The newer NNX API uses PyTorch-style __init__/__call__ with mutable state; the older Linen API is closer to Haiku's functional style.
  • Equinox (JAX) — Neural networks as callable PyTrees. Models are plain Python dataclasses; parameters live in the tree rather than a separate registry, making them compatible with jax.jit/jax.grad directly.
  • torch.func (PyTorch) — PyTorch's built-in functional transforms (formerly functorch). torch.func.functional_call lets you call an existing nn.Module with an explicit parameter dict, enabling per-sample gradients, meta-learning, etc.
  • PyTorch Lightning (PyTorch) — Training loop abstraction over nn.Module. Reduces boilerplate around the train/val/test cycle but keeps the imperative nn.Module programming model.
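
For example, torch.func.functional_call (available in PyTorch 2.x) runs an existing module with an explicit parameter dict in place of its registered parameters:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Linear(4, 2)
# Override the registered parameters with an explicit dict for one call.
zeros = {k: torch.zeros_like(v) for k, v in model.named_parameters()}
out = functional_call(model, zeros, (torch.ones(1, 4),))
# An all-zero weight and bias force an all-zero output.
```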
