
A PyTorch adapter for forward-only model definition


blaze logo

Blaze: Write less. Build more.


A PyTorch adapter inspired by Haiku's functional programming model. Write forward-only models using inline layer calls, with no nn.Module boilerplate, and get back a proper nn.Module with full parameter management and other goodies.


✨ Why blaze?

The traditional way to define PyTorch models makes you write every layer twice (declared in __init__, used in forward) and requires naming each one (the arch-nemesis of programmers), which drastically slows down iterative prototyping. You also have to manually track and hardcode input sizes for every layer, even when they could be trivially computed from the previous layer's output. Blaze removes all of that: layers are written once, inline, exactly where they're used, and input sizes can be read straight from the live tensor via x.shape during init().

Example: Convolutional network

Traditional PyTorch:

class ConvNet(nn.Module):
    def __init__(self):          # ← boilerplate you must write
        super().__init__()       # ← boilerplate you must write
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)  # ← named here...
        self.bn1   = nn.BatchNorm2d(32)               # ← named here...
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)  # ← named here...
        self.bn2   = nn.BatchNorm2d(64)               # ← named here...
        self.pool  = nn.AdaptiveAvgPool2d(1)          # ← named here...
        self.fc    = nn.Linear(64, 10)                # ← named here & must know input size!

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))  # ← ...and used here
        x = F.relu(self.bn2(self.conv2(x)))  # ← ...and used here
        x = self.pool(x).flatten(1)          # ← ...and used here
        return self.fc(x)                    # ← ...and used here, what's the output dim again?

model = ConvNet()

Blaze:

# No class. No __init__. No self. No invented names. Only logic.
def forward(x):
    x = F.relu(bl.BatchNorm2d(32)(bl.Conv2d(3, 32, 3, padding=1)(x)))
    x = F.relu(bl.BatchNorm2d(64)(bl.Conv2d(32, 64, 3, padding=1)(x)))
    x = bl.AdaptiveAvgPool2d(1)(x).flatten(1)
    return bl.Linear(x.shape[-1], 10)(x)  # ← input size computed from the tensor

model = bl.transform(forward)
model.init(torch.randn(1, 3, 32, 32)) # discovers and creates all modules

🗑️ What gets eliminated

| PyTorch requirement | With blaze |
|---|---|
| class MyModel(nn.Module) | Plain function or thin bl.Module subclass |
| def __init__(self) | Not needed |
| super().__init__() | Not needed |
| self.layer = nn.Linear(...) | Not needed; layers are created inline |
| Inventing a name for every layer | Auto-derived from the class name and deduplicated |
| nn.ModuleList / nn.ModuleDict for dynamic structure | A plain Python loop or dict (see the sketch after this table) |
| Passing hyperparameters through __init__ to store for forward | Just use them directly in the function |
| Manually tracking input sizes across layers | Use x.shape; sizes are inferred from the live tensor during init() |
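
For example, a variable-depth stack that would normally require nn.ModuleList can be written as a plain loop, because each layer call made during init() is registered and then reused in order on later calls. A minimal sketch under that assumption; the depth and width values are illustrative, and the hyperparameters are passed as static kwargs to bl.transform (shown in the Quickstart below).

import torch
import blaze as bl

def forward(x, depth, width):
    for _ in range(depth):            # no nn.ModuleList needed
        x = bl.Linear(x.shape[-1], width)(x)
        x = bl.ReLU()(x)
    return bl.Linear(x.shape[-1], 1)(x)

model = bl.transform(forward, depth=3, width=64)
model.init(torch.randn(4, 10))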

🚀 Features

  • 🧹 No nn.Module boilerplate: define models as plain functions; layers are called inline.
  • 🔌 Drop-in compatible: BlazeModule is a standard nn.Module; training loops, optimizers, state_dict, and deployment code need no changes (see the checkpointing sketch after this list).
  • ⚙️ Automatic parameter management: weights are discovered on the first init() pass, reused on every subsequent call, and organized into hierarchical paths (e.g. "block.linear") derived automatically from class names (or overridden with name=).
  • 📏 Dynamic size inference: since init() runs with a real tensor, layer sizes can be computed from x.shape instead of hardcoded; no more manually tracking dimensions across layers.
  • 🧩 Composable modules: subclass bl.Module to build reusable components; scopes nest correctly no matter how deep.
  • 🎛️ Raw parameters: get_parameter() creates a learnable nn.Parameter scoped to the current path, without any surrounding module.
  • 💾 Non-trainable state: get_state() / set_state() create and update buffer tensors (analogous to hk.get_state / hk.set_state) that are tracked by the module but excluded from gradient updates.
  • ⚡ torch.jit.script/trace and torch.compile (experimental): after init(), models can be compiled for performance and deployment.
  • 🧱 Built-in layer wrappers: covers linear, conv, norm, pooling, activation, dropout, recurrent, embedding, attention, transformer, and shape layers.
  • 🔌 Seamless integration: bl.wrap lets you use any existing nn.Module, pretrained model, or third-party layer directly inside a blaze function.
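
As promised in the drop-in compatibility bullet above, the usual nn.Module checkpointing workflow carries over unchanged. A minimal sketch with a throwaway forward function; it assumes a fresh model must run init() once so that its modules exist before load_state_dict() is called.

import torch
import blaze as bl

def forward(x):
    x = bl.Linear(x.shape[-1], 32)(x)
    x = bl.ReLU()(x)
    return bl.Linear(32, 1)(x)

model = bl.transform(forward)
model.init(torch.randn(4, 10))

torch.save(model.state_dict(), "blaze_mlp.pt")      # plain nn.Module checkpointing

restored = bl.transform(forward)
restored.init(torch.randn(4, 10))                   # recreate the modules first
restored.load_state_dict(torch.load("blaze_mlp.pt"))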

📦 Installation

pip install blaze-torch

🧑‍💻 Quickstart

import torch
import blaze as bl

def forward(x, hidden_dim):
    x = bl.Linear(x.shape[-1], hidden_dim)(x)
    x = bl.ReLU()(x)
    x = bl.Linear(hidden_dim, 1)(x)
    return x

model = bl.transform(forward, hidden_dim=128) # pass kwargs as static hyperparameters
model.init(torch.randn(4, 10))   # discovers and creates all modules

out = model(torch.randn(4, 10))  # normal nn.Module usage

📖 Core concepts

๐Ÿ” Two-phase execution

bl.transform wraps your forward function. Calling .init(sample_input) runs an INIT pass that discovers every layer call and registers it into an internal registry keyed by its hierarchical path (e.g. "block.linear"). Subsequent calls run in APPLY mode, reusing the registered modules by call order.

model = bl.transform(forward)                # empty model
model.init(torch.randn(batch, in_features))  # INIT pass: creates weights
output = model(x)                            # APPLY pass: reuses weights
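
The registered modules behave like regular submodules after init(): parameters are visible through the standard nn.Module introspection APIs, and every later call reuses the same weights. A minimal sketch with a throwaway forward function; the exact parameter-name prefixes depend on how blaze lays out its internal registry, so treat the printed names as illustrative.

import torch
import blaze as bl

def forward(x):
    x = bl.Linear(x.shape[-1], 32, name="encoder")(x)
    return bl.Linear(32, 1, name="head")(x)

model = bl.transform(forward)
model.init(torch.randn(4, 10))              # INIT pass: modules are created here

for name, p in model.named_parameters():    # standard nn.Module introspection
    print(name, tuple(p.shape))

y_small = model(torch.randn(4, 10))         # APPLY pass
y_large = model(torch.randn(64, 10))        # different batch size, same weights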

๐Ÿ‹๏ธ Training

BlazeModule is a standard nn.Module, so you can use any PyTorch optimizer:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

🧩 User-defined modules (bl.Module)

Subclass bl.Module and implement __call__ to group layers into reusable components that share a scope. The class name is automatically converted to snake_case for scoping, and repeated instantiations are deduplicated with a numeric suffix.

class MLP(bl.Module):
    def __call__(self, x):
        x = bl.Linear(x.shape[-1], 128)(x)
        x = bl.GELU()(x)
        x = bl.Linear(128, x.shape[-1])(x)
        return x

def forward(x):
    x = MLP()(x)   # parameter names: "mlp.linear", "mlp.gelu", "mlp.linear_1"
    x = MLP()(x)   # parameter names: "mlp_1.linear", ...
    return x

🎛️ Raw parameters (bl.get_parameter)

Create a learnable nn.Parameter directly, scoped to the current name context. Analogous to hk.get_parameter.

def forward(x):
    scale = bl.get_parameter("scale", (x.shape[-1],), init_fn=torch.ones)
    bias  = bl.get_parameter("bias",  (x.shape[-1],), init_fn=torch.zeros)
    return x * scale + bias

💾 Non-trainable state (bl.get_state / bl.set_state)

Create and update buffer tensors (non-trainable, tracked by the module). Analogous to hk.get_state / hk.set_state.

def forward(x):
    running_mean = bl.get_state("running_mean", (x.shape[-1],), init_fn=torch.zeros)
    bl.set_state("running_mean", running_mean * 0.9 + x.mean(0) * 0.1)
    return x - running_mean

🔌 Wrapping existing modules (bl.wrap)

Have an existing nn.Module, a pretrained model, or a third-party layer? bl.wrap lets you use it directly inside a blaze function, with no subclassing or redefining needed:

def forward(x):
    encoder = bl.wrap(lambda: torchvision.models.resnet18(pretrained=True))
    x = encoder(x)
    x = bl.Linear(x.shape[-1], 10)(x)
    return x

The factory is called once during init() and the resulting module is reused on every subsequent forward call. Pass name= to override the default registry key:

x = bl.wrap(lambda: nn.Linear(10, 64), name="encoder")(x)

๐Ÿท๏ธ Custom names

Pass name= to any layer call to override the auto-derived name:

def forward(x):
    x = bl.Linear(10, 64, name="encoder")(x)
    x = bl.Linear(64, 10, name="decoder")(x)
    return x

๐Ÿ—๏ธ Custom initialization (init_fn)

All layer wrappers and bl.Module subclasses accept an init_fn= callback that runs once during init() and is skipped on subsequent forward calls:

def forward(x):
    x = bl.Linear(10, 64, init_fn=lambda m: nn.init.xavier_uniform_(m.weight))(x)
    x = bl.Conv2d(64, 32, 3, padding=1, init_fn=lambda m: nn.init.kaiming_normal_(m.weight))(x)
    x = bl.BatchNorm2d(32, init_fn=lambda m: nn.init.ones_(m.weight))(x)
    return x

Works with bl.Module subclasses too:

class Block(bl.Module):
    def __init__(self, dim, init_fn=None):
        super().__init__(init_fn=init_fn)
        self.dim = dim

    def __call__(self, x):
        return bl.Linear(x.shape[-1], self.dim)(x)

def forward(x):
    return Block(dim=64, init_fn=my_custom_init)(x)

⚡ Compilation

After .init(), models work with torch.jit.trace, torch.jit.script, and torch.compile:

model = bl.transform(forward)
model.init(torch.randn(2, 10))

traced  = torch.jit.trace(model, torch.randn(2, 10))
scripted = torch.jit.script(model)
compiled = torch.compile(model)

🧱 Available layers

All wrappers accept the same arguments as their torch.nn counterparts, along with optional arguments name and init_fn.

| Category | Layers |
|---|---|
| Linear | Linear, Bilinear |
| Conv | Conv1d/2d/3d, ConvTranspose1d/2d/3d |
| Norm | BatchNorm1d/2d/3d, SyncBatchNorm, InstanceNorm1d/2d/3d, LayerNorm, GroupNorm, RMSNorm |
| Pooling | MaxPool1d/2d/3d, AvgPool1d/2d/3d, AdaptiveAvgPool1d/2d/3d, AdaptiveMaxPool1d/2d/3d |
| Activation | ReLU, ReLU6, LeakyReLU, PReLU, ELU, SELU, CELU, GELU, Mish, SiLU, Tanh, Sigmoid, Hardsigmoid, Hardswish, Softmax, LogSoftmax, Softplus |
| Dropout | Dropout, Dropout1d/2d/3d, AlphaDropout |
| Recurrent | LSTM, GRU, RNN, LSTMCell, GRUCell, RNNCell |
| Embedding | Embedding, EmbeddingBag |
| Attention | MultiheadAttention |
| Transformer | Transformer, TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer |
| Shape | Flatten, Unflatten, Upsample, PixelShuffle, PixelUnshuffle |
| Misc | Identity |
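
The wrappers outside the feed-forward families follow the same pattern: pass exactly the arguments the torch.nn layer takes, call the wrapper on the input, and optionally add name= or init_fn=. A minimal sketch of a small sequence model; it assumes bl.LSTM mirrors nn.LSTM's signature and returns the usual (output, (h_n, c_n)) tuple, and the dimensions are illustrative.

import torch
import blaze as bl

def forward(x):                                       # x: (batch, seq_len, features)
    out, _ = bl.LSTM(x.shape[-1], 64, batch_first=True, name="encoder")(x)
    return bl.Linear(out.shape[-1], 5)(out[:, -1])    # classify from the last time step

model = bl.transform(forward)
model.init(torch.randn(4, 16, 8))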

🔗 Related projects

| Project | Framework | Description |
|---|---|---|
| dm-haiku | JAX | The original inspiration. Transforms stateful hk.Module code into pure (init, apply) function pairs via hk.transform. |
| Flax NNX | JAX | Google's neural network library for JAX. The newer NNX API uses PyTorch-style __init__/__call__ with mutable state; the older Linen API is closer to Haiku's functional style. |
| Equinox | JAX | Neural networks as callable PyTrees. Models are plain Python dataclasses; parameters live in the tree rather than a separate registry, making them compatible with jax.jit/jax.grad directly. |
| torch.func | PyTorch | PyTorch's built-in functional transforms (formerly functorch). torch.func.functional_call lets you call an existing nn.Module with an explicit parameter dict, enabling per-sample gradients, meta-learning, etc. |
| PyTorch Lightning | PyTorch | Training loop abstraction over nn.Module. Reduces boilerplate around the train/val/test cycle but keeps the imperative nn.Module programming model. |

📄 License

MIT
