A PyTorch adapter for forward-only model definition
Blaze: Write less. Build more.
A PyTorch adapter inspired by Haiku's functional programming model. Write forward-only models using inline layer calls (no nn.Module boilerplate) and get back a proper nn.Module with full parameter management and other goodies.
Table of Contents
- Table of Contents
- Why blaze?
- Features
- Installation
- Quickstart
- Core concepts
- Two-phase execution
- Training
- User-defined modules (bl.Module)
- Raw parameters (bl.get_parameter)
- Non-trainable state (bl.get_state / bl.set_state)
- Wrapping existing modules (bl.wrap)
- Custom names
- Custom initialization (init_fn)
- Compilation
- Available layers
- Related projects
- License
Why blaze?
The traditional way to define PyTorch models makes you write every layer twice (declared in __init__, used in forward) and forces you to invent a name for each one (the arch-nemesis of programmers), drastically slowing down iterative prototyping. You also have to manually track and hardcode input sizes for every layer, even when they could be trivially computed from the previous layer's output. blaze removes all of that: layers are written once, inline, exactly where they're used, and input sizes can be read straight from the live tensor via x.shape during init().
Example: Convolutional network
Traditional PyTorch:
class ConvNet(nn.Module):
    def __init__(self):                               # <- boilerplate you must write
        super().__init__()                            # <- boilerplate you must write
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)   # <- named here...
        self.bn1 = nn.BatchNorm2d(32)                 # <- named here...
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)  # <- named here...
        self.bn2 = nn.BatchNorm2d(64)                 # <- named here...
        self.pool = nn.AdaptiveAvgPool2d(1)           # <- named here...
        self.fc = nn.Linear(64, 10)                   # <- named here & must know input size!

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))           # <- ...and used here
        x = F.relu(self.bn2(self.conv2(x)))           # <- ...and used here
        x = self.pool(x).flatten(1)                   # <- ...and used here
        return self.fc(x)                             # <- ...and used here, what's the output dim again?

model = ConvNet()
Blaze:
# No class. No __init__. No self. No invented names. Only logic.
def forward(x):
    x = F.relu(bl.BatchNorm2d(32)(bl.Conv2d(3, 32, 3, padding=1)(x)))
    x = F.relu(bl.BatchNorm2d(64)(bl.Conv2d(32, 64, 3, padding=1)(x)))
    x = bl.AdaptiveAvgPool2d(1)(x).flatten(1)
    return bl.Linear(x.shape[-1], 10)(x)  # <- input size computed from the tensor
model = bl.transform(forward)
model.init(torch.randn(1, 3, 32, 32)) # discovers and creates all modules
What gets eliminated

| PyTorch requirement | With blaze |
|---|---|
| class MyModel(nn.Module) | Plain function or thin bl.Module subclass |
| def __init__(self) | Not needed |
| super().__init__() | Not needed |
| self.layer = nn.Linear(...) | Not needed; layers are created inline |
| Inventing a name for every layer | Auto-derived from the class name and deduplicated |
| nn.ModuleList / nn.ModuleDict for dynamic structure | A plain Python loop or dict (see the sketch below) |
| Passing hyperparameters through __init__ to store for forward | Just use them directly in the function |
| Manually tracking input sizes across layers | Use x.shape; sizes are inferred from the live tensor during init() |
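For instance, structure that would normally need nn.ModuleList can be written as an ordinary Python loop. A minimal sketch, assuming (as described under Two-phase execution below) that each inline call made during init() is registered and then reused in the same call order; the depth and widths here are illustrative:

def forward(x, depth):
    # each loop iteration registers (and later reuses) its own bl.Linear / bl.GELU pair
    for _ in range(depth):
        x = bl.GELU()(bl.Linear(x.shape[-1], 128)(x))
    return bl.Linear(x.shape[-1], 10)(x)

model = bl.transform(forward, depth=4)  # depth passed as a static hyperparameter
model.init(torch.randn(8, 32))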
Features
- No nn.Module boilerplate: define models as plain functions; layers are called inline.
- Drop-in compatible: BlazeModule is a standard nn.Module; training loops, optimizers, state_dict, and deployment code need no changes.
- Automatic parameter management: weights are discovered on the first init() pass, reused on every subsequent call, and organized into hierarchical paths (e.g. "block.linear") derived automatically from class names (or overridden with name=).
- Dynamic size inference: since init() runs with a real tensor, layer sizes can be computed from x.shape instead of being hardcoded; no more manually tracking dimensions across layers.
- Composable modules: subclass bl.Module to build reusable components; scopes nest correctly no matter how deep.
- Raw parameters: get_parameter() creates a learnable nn.Parameter scoped to the current path, without any surrounding module.
- Non-trainable state: get_state() / set_state() create and update buffer tensors (analogous to hk.get_state / hk.set_state) that are tracked by the module but excluded from gradient updates.
- torch.jit.script/trace and torch.compile (experimental): after init(), models can be compiled for performance and deployment.
- Built-in layer wrappers: covers linear, conv, norm, pooling, activation, dropout, recurrent, embedding, attention, transformer, and shape layers.
- Seamless integration: bl.wrap lets you use any existing nn.Module, pretrained model, or third-party layer directly inside a blaze function.
Installation
pip install blaze-torch
Quickstart
import torch
import blaze as bl
def forward(x, hidden_dim):
    x = bl.Linear(x.shape[-1], hidden_dim)(x)
    x = bl.ReLU()(x)
    x = bl.Linear(hidden_dim, 1)(x)
    return x
model = bl.transform(forward, hidden_dim=128) # pass kwargs as static hyperparameters
model.init(torch.randn(4, 10)) # discovers and creates all modules
out = model(torch.randn(4, 10)) # normal nn.Module usage
Core concepts
Two-phase execution
bl.transform wraps your forward function. Calling .init(sample_input) runs an INIT pass that discovers every layer call and registers it into an internal registry keyed by its hierarchical path (e.g. "block.linear"). Subsequent calls run in APPLY mode, reusing the registered modules by call order.
model = bl.transform(forward) # empty model
model.init(torch.randn(batch, in_features))  # INIT pass: creates weights
output = model(x)                            # APPLY pass: reuses weights
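Because everything is registered through the standard nn.Module machinery, the result of the INIT pass can be inspected with the usual PyTorch APIs. A small sketch; the exact key strings printed depend on the auto-derived hierarchical paths described above and are only illustrative here:

model = bl.transform(forward)
model.init(torch.randn(4, 10))

# e.g. keys such as "linear.weight", "linear_1.bias", ...
for name, param in model.named_parameters():
    print(name, tuple(param.shape))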
Training
BlazeModule is a standard nn.Module; use any PyTorch optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
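Checkpointing works through the ordinary state_dict API as well, since the transformed model is a plain nn.Module. A minimal sketch, assuming a fresh model must be init()-ed first so its registered submodules exist before load_state_dict:

# save
torch.save(model.state_dict(), "checkpoint.pt")

# load into a freshly transformed model
restored = bl.transform(forward)
restored.init(sample_input)  # same sample input used for the original init() (assumption)
restored.load_state_dict(torch.load("checkpoint.pt"))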
User-defined modules (bl.Module)
Subclass bl.Module and implement __call__ to group layers into reusable components that share a scope. The class name is automatically converted to snake_case for scoping, and repeated instantiations are deduplicated with a numeric suffix.
class MLP(bl.Module):
    def __call__(self, x):
        x = bl.Linear(x.shape[-1], 128)(x)
        x = bl.GELU()(x)
        x = bl.Linear(128, x.shape[-1])(x)
        return x

def forward(x):
    x = MLP()(x)  # parameter names: "mlp.linear", "mlp.gelu", "mlp.linear_1"
    x = MLP()(x)  # parameter names: "mlp_1.linear", ...
    return x
Raw parameters (bl.get_parameter)
Create a learnable nn.Parameter directly, scoped to the current name context. Analogous to hk.get_parameter.
def forward(x):
    scale = bl.get_parameter("scale", (x.shape[-1],), init_fn=torch.ones)
    bias = bl.get_parameter("bias", (x.shape[-1],), init_fn=torch.zeros)
    return x * scale + bias
Non-trainable state (bl.get_state / bl.set_state)
Create and update buffer tensors (non-trainable, tracked by the module). Analogous to hk.get_state / hk.set_state.
def forward(x):
    running_mean = bl.get_state("running_mean", (x.shape[-1],), init_fn=torch.zeros)
    bl.set_state("running_mean", running_mean * 0.9 + x.mean(0) * 0.1)
    return x - running_mean
Wrapping existing modules (bl.wrap)
Have an existing nn.Module, a pretrained model, or a third-party layer? bl.wrap lets you use it directly inside a blaze function, with no subclassing or redefining needed:
def forward(x):
    encoder = bl.wrap(lambda: torchvision.models.resnet18(pretrained=True))
    x = encoder(x)
    x = bl.Linear(x.shape[-1], 10)(x)
    return x
The factory is called once during init() and the resulting module is reused on every subsequent forward call. Pass name= to override the default registry key:
x = bl.wrap(lambda: nn.Linear(10, 64), name="encoder")(x)
Custom names
Pass name= to any layer call to override the auto-derived name:
def forward(x):
    x = bl.Linear(10, 64, name="encoder")(x)
    x = bl.Linear(64, 10, name="decoder")(x)
    return x
Custom initialization (init_fn)
All layer wrappers and bl.Module subclasses accept an init_fn= callback that runs once during init() and is skipped on subsequent forward calls:
def forward(x):
    x = bl.Conv2d(3, 32, 3, padding=1, init_fn=lambda m: nn.init.kaiming_normal_(m.weight))(x)
    x = bl.BatchNorm2d(32, init_fn=lambda m: nn.init.ones_(m.weight))(x)
    x = x.flatten(1)
    return bl.Linear(x.shape[-1], 64, init_fn=lambda m: nn.init.xavier_uniform_(m.weight))(x)
Works with bl.Module subclasses too:
class Block(bl.Module):
    def __init__(self, dim, init_fn=None):
        super().__init__(init_fn=init_fn)
        self.dim = dim

    def __call__(self, x):
        return bl.Linear(x.shape[-1], self.dim)(x)

def forward(x):
    return Block(dim=64, init_fn=my_custom_init)(x)
Compilation
After .init(), models work with torch.jit.trace, torch.jit.script, and torch.compile:
model = bl.transform(forward)
model.init(torch.randn(2, 10))
traced = torch.jit.trace(model, torch.randn(2, 10))
scripted = torch.jit.script(model)
compiled = torch.compile(model)
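The traced or scripted result can then be saved and reloaded with the standard TorchScript tooling; this is plain PyTorch, nothing blaze-specific:

traced.save("model_traced.pt")
loaded = torch.jit.load("model_traced.pt")
out = loaded(torch.randn(2, 10))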
Available layers
All wrappers accept the same arguments as their torch.nn counterparts, along with optional arguments name and init_fn.
| Category | Layers |
|---|---|
| Linear | Linear, Bilinear |
| Conv | Conv1d/2d/3d, ConvTranspose1d/2d/3d |
| Norm | BatchNorm1d/2d/3d, SyncBatchNorm, InstanceNorm1d/2d/3d, LayerNorm, GroupNorm, RMSNorm |
| Pooling | MaxPool1d/2d/3d, AvgPool1d/2d/3d, AdaptiveAvgPool1d/2d/3d, AdaptiveMaxPool1d/2d/3d |
| Activation | ReLU, ReLU6, LeakyReLU, PReLU, ELU, SELU, CELU, GELU, Mish, SiLU, Tanh, Sigmoid, Hardsigmoid, Hardswish, Softmax, LogSoftmax, Softplus |
| Dropout | Dropout, Dropout1d/2d/3d, AlphaDropout |
| Recurrent | LSTM, GRU, RNN, LSTMCell, GRUCell, RNNCell |
| Embedding | Embedding, EmbeddingBag |
| Attention | MultiheadAttention |
| Transformer | Transformer, TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer |
| Shape | Flatten, Unflatten, Upsample, PixelShuffle, PixelUnshuffle |
| Misc | Identity |
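For example, wrappers from different categories compose in the same inline style. A hypothetical sketch (the vocabulary size, dimensions, and name below are illustrative; each wrapper takes the same constructor arguments as its torch.nn counterpart):

def forward(tokens):
    x = bl.Embedding(10000, 256, name="tok_embed")(tokens)
    x = bl.LayerNorm(x.shape[-1])(x)
    x = bl.Dropout(0.1)(x)
    return bl.Linear(x.shape[-1], 10)(x)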
Related projects
| Project | Framework | Description |
|---|---|---|
| dm-haiku | JAX | The original inspiration. Transforms stateful hk.Module code into pure (init, apply) function pairs via hk.transform. |
| Flax NNX | JAX | Google's neural network library for JAX. The newer NNX API uses PyTorch-style __init__/__call__ with mutable state; the older Linen API is closer to Haiku's functional style. |
| Equinox | JAX | Neural networks as callable PyTrees. Models are plain Python dataclasses; parameters live in the tree rather than a separate registry, making them compatible with jax.jit/jax.grad directly. |
| torch.func | PyTorch | PyTorch's built-in functional transforms (formerly functorch). torch.func.functional_call lets you call an existing nn.Module with an explicit parameter dict, enabling per-sample gradients, meta-learning, etc. |
| PyTorch Lightning | PyTorch | Training loop abstraction over nn.Module. Reduces boilerplate around the train/val/test cycle but keeps the imperative nn.Module programming model. |
License