torchresidual
Flexible residual connections for PyTorch with a clean, composable API.
Build complex residual architectures without boilerplate. torchresidual provides
Record and Apply modules that let you create skip connections of any depth,
with automatic shape handling and learnable mixing coefficients.
📖 Quick Start | 📚 Full Documentation | 💡 Examples | ❓ FAQ
Why torchresidual?
Standard PyTorch:
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        residual = x
        x = self.linear(x)
        x = F.relu(x)
        x = self.norm(x)
        return x + residual  # Manual residual
With torchresidual:
block = ResidualSequential(
    Record(name="input"),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.LayerNorm(64),
    Apply(record_name="input"),  # Automatic residual
)
Benefits:
- No custom forward() methods
- Multiple skip connections with named records
- Automatic projection when dimensions change
- Five residual operations (add, concat, multiply, gated, highway)
- Learnable mixing coefficients
- Works with LSTMs, attention, and any nn.Module
Installation
pip install torchresidual
Requirements: Python ≥3.8, PyTorch ≥1.9
New to torchresidual? See the Quick Start Guide for a 5-minute tutorial.
Quick Start
Basic residual connection
import torch
import torch.nn as nn
from torchresidual import ResidualSequential, Record, Apply
block = ResidualSequential(
    Record(name="input"),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.LayerNorm(64),
    Apply(record_name="input", operation="add"),
)
x = torch.randn(8, 64)
out = block(x) # Shape: [8, 64]
Multiple skip connections
block = ResidualSequential(
    Record(name="input", need_projection=True),
    nn.Linear(64, 32),
    nn.ReLU(),
    Record(name="mid"),
    nn.Linear(32, 64),
    Apply(record_name="input"),  # Long skip with projection
    nn.LayerNorm(64),
    nn.Linear(64, 32),
    Apply(record_name="mid"),  # Short skip
)
Learnable mixing coefficient
from torchresidual import LearnableAlpha
block = ResidualSequential(
    Record(name="r"),
    nn.Linear(64, 64),
    Apply(
        record_name="r",
        operation="gated",
        alpha=LearnableAlpha(0.3, min_value=0.0, max_value=1.0),
    ),
)
# Alpha is learned during training
optimizer = torch.optim.Adam(block.parameters(), lr=1e-3)
Automatic projection for shape changes
# Input: [batch, 64] → Output: [batch, 128]
block = ResidualSequential(
    Record(name="r", need_projection=True),  # Enables auto-projection
    nn.Linear(64, 128),
    nn.ReLU(),
    Apply(record_name="r"),  # Automatically projects 64→128
)
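Conceptually, need_projection=True gives the skip branch its own learned linear map so shapes match before the residual is applied. A rough plain-PyTorch equivalent, assuming the projection is a simple nn.Linear (illustrative sketch, not the library's internal code):

import torch
import torch.nn as nn

# Manual equivalent of the block above: project the skip branch from 64 to 128
# so it can be added to the main branch (illustrative only).
linear = nn.Linear(64, 128)
proj = nn.Linear(64, 128)  # stands in for the auto-created projection

x = torch.randn(8, 64)
out = torch.relu(linear(x)) + proj(x)
print(out.shape)  # torch.Size([8, 128])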
LSTM with residual
from torchresidual import RecurrentWrapper
block = ResidualSequential(
    Record(name="r"),
    RecurrentWrapper(
        nn.LSTM(32, 32, num_layers=2, batch_first=True),
        return_hidden=False,
    ),
    Apply(record_name="r"),
)
x = torch.randn(4, 10, 32) # [batch, seq_len, features]
out = block(x)
API Reference
Core Components
ResidualSequential(*modules)
Drop-in replacement for nn.Sequential with residual connection support.
Example:
block = ResidualSequential(
    nn.Linear(64, 64),
    Record(),
    nn.ReLU(),
    Apply(),
)
Record(need_projection=False, name=None)
Saves the current tensor for later use in a residual connection.
Args:
- need_projection (bool): If True, Apply will create a linear projection when shapes don't match
- name (str, optional): Label for this record point. Auto-assigned if None.
Example:
Record(name="input", need_projection=True)
Apply(operation="add", record_name=None, alpha=1.0)
Applies a residual connection using a previously recorded tensor.
Args:
- operation (str): One of "add", "concat", "multiply", "gated", "highway"
- record_name (str, optional): Which Record to use. If None, uses the most recent.
- alpha (float or LearnableAlpha): Scaling factor for the residual branch
Operations:
| Operation | Formula | Use case |
|---|---|---|
| add | x + α·r | Standard ResNet-style |
| concat | cat([x, r], dim=-1) | DenseNet-style |
| multiply | x·(1 + α·r) | Multiplicative skip |
| gated | (1-α)·x + α·r | Learnable interpolation |
| highway | T·x + C·r | Highway Networks |
Example:
Apply(operation="gated", record_name="input", alpha=0.5)
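For intuition, here is what the table's formulas compute on plain tensors, where x is the block output, r the recorded tensor, and alpha the mixing coefficient (illustrative sketch, not the library's internal code):

import torch

x, r, alpha = torch.randn(8, 64), torch.randn(8, 64), 0.5

add      = x + alpha * r                # ResNet-style
concat   = torch.cat([x, r], dim=-1)    # doubles the last dimension
multiply = x * (1 + alpha * r)          # multiplicative skip
gated    = (1 - alpha) * x + alpha * r  # convex interpolation for 0 ≤ alpha ≤ 1
# highway: T·x + C·r, where T and C are learned gates (Highway Networks)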
LearnableAlpha(initial_value, min_value=0.0, max_value=1.0, use_log_space=None)
Learnable scalar parameter constrained to [min_value, max_value].
Args:
- initial_value (float): Starting value
- min_value (float): Lower bound (inclusive)
- max_value (float): Upper bound (inclusive)
- use_log_space (bool, optional): Force log or linear parameterization. Auto-detected if None.
Example:
alpha = LearnableAlpha(0.5, min_value=0.0, max_value=1.0)
x = x + alpha() * residual # alpha() returns constrained value
RecurrentWrapper(module, return_hidden=False)
Wraps LSTM/GRU modules for seamless integration with ResidualSequential.
Args:
- module (nn.Module): The recurrent module (e.g., nn.LSTM)
- return_hidden (bool): If True, returns an (output, hidden) tuple
Example:
RecurrentWrapper(nn.LSTM(64, 64, batch_first=True), return_hidden=False)
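The wrapper is needed because nn.LSTM returns an (output, (h_n, c_n)) tuple, which sequential composition cannot pass on to the next module. A minimal sketch of the idea (hypothetical code, not RecurrentWrapper's actual implementation):

import torch
import torch.nn as nn

class UnwrapLSTM(nn.Module):
    """Return only the LSTM output tensor, dropping the hidden state."""
    def __init__(self, lstm):
        super().__init__()
        self.lstm = lstm

    def forward(self, x):
        output, _hidden = self.lstm(x)  # nn.LSTM returns (output, (h_n, c_n))
        return output

wrapped = UnwrapLSTM(nn.LSTM(64, 64, batch_first=True))
print(wrapped(torch.randn(4, 10, 64)).shape)  # torch.Size([4, 10, 64])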
Advanced Examples
Transformer-style block
# Multi-head attention with residual and layer norm
block = ResidualSequential(
    Record(name="input"),
    nn.MultiheadAttention(embed_dim=256, num_heads=8),
    Apply(record_name="input"),
    nn.LayerNorm(256),
    Record(name="attn_out"),
    nn.Linear(256, 1024),
    nn.ReLU(),
    nn.Linear(1024, 256),
    Apply(record_name="attn_out"),
    nn.LayerNorm(256),
)
Nested residual blocks
inner_block = ResidualSequential(
    Record(),
    nn.Linear(64, 64),
    nn.ReLU(),
    Apply(),
)

outer_block = ResidualSequential(
    Record(),
    inner_block,
    nn.Linear(64, 64),
    Apply(),
)
Complex encoder block
from collections import OrderedDict
encoder = ResidualSequential(OrderedDict([
    ('record_input', Record(need_projection=True, name="input")),
    ('conv1', nn.Conv1d(64, 128, kernel_size=3, padding=1)),
    ('relu1', nn.ReLU()),
    ('record_mid', Record(name="mid")),
    ('conv2', nn.Conv1d(128, 128, kernel_size=3, padding=1)),
    ('relu2', nn.ReLU()),
    ('apply_long', Apply(record_name="input")),
    ('norm', nn.BatchNorm1d(128)),
    ('conv3', nn.Conv1d(128, 64, kernel_size=1)),
    ('apply_short', Apply(record_name="mid", operation="concat")),
]))
Compatibility
Supported Environments
| Environment | Status | Notes |
|---|---|---|
| Single GPU training | ✅ | Full support |
| CPU training | ✅ | Full support |
| nn.DataParallel | ✅ | Thread-safe via threading.local() |
| DistributedDataParallel | ✅ | Process-safe, recommended for multi-GPU |
| Multi-threaded inference | ✅ | Safe for Flask/FastAPI servers |
| Jupyter notebooks | ✅ | Full support |
| torch.jit.script | ❌ | Planned for v1.1 |
| ONNX export | ❌ | Planned for v1.1 |
Thread Safety
torchresidual uses threading.local() for context management, making it safe for:
- nn.DataParallel (multiple GPU threads)
- Multi-threaded inference servers
- Concurrent requests in production
See docs/DESIGN.md for implementation details.
Design Philosophy
Why thread-local storage?
Traditional approaches store a parent reference in Apply, creating circular references:
ResidualSequential → Apply → ResidualSequential # Breaks pickle/deepcopy
torchresidual uses threading.local() to avoid this:
- ✅ No circular references
- ✅ Works with pickle, torch.save, deepcopy
- ✅ Thread-safe for nn.DataParallel
- ✅ Clean module hierarchy
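A minimal sketch of the thread-local pattern (hypothetical names, not torchresidual's actual internals): recorded tensors live in threading.local() storage rather than on the Apply module, so Apply never holds a reference back to its container.

import threading
import torch
import torch.nn as nn

_context = threading.local()  # per-thread record storage

class TinyRecord(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, x):
        if not hasattr(_context, "records"):
            _context.records = {}
        _context.records[self.name] = x  # no parent reference, no circular refs
        return x

class TinyApply(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, x):
        return x + _context.records[self.name]

block = nn.Sequential(TinyRecord("r"), nn.Linear(8, 8), TinyApply("r"))
print(block(torch.randn(2, 8)).shape)  # torch.Size([2, 8])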
Why tanh parameterization?
LearnableAlpha uses tanh (not sigmoid) for bounded parameters:
- Better gradient flow near boundaries
- Symmetric around midpoint
- Stable training dynamics
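A rough sketch of a tanh-bounded scalar to illustrate the idea (not necessarily LearnableAlpha's exact parameterization; assumes min_value < initial < max_value):

import torch
import torch.nn as nn

class BoundedScalar(nn.Module):
    """Unconstrained parameter squashed into [min_value, max_value] via tanh."""
    def __init__(self, initial, min_value=0.0, max_value=1.0):
        super().__init__()
        self.min_value, self.max_value = min_value, max_value
        frac = (initial - min_value) / (max_value - min_value)  # map initial into (0, 1)
        self.raw = nn.Parameter(torch.atanh(torch.tensor(2.0 * frac - 1.0)))

    def forward(self):
        frac = 0.5 * (torch.tanh(self.raw) + 1.0)  # back to (0, 1)
        return self.min_value + frac * (self.max_value - self.min_value)

alpha = BoundedScalar(0.3)
print(alpha())  # tensor(0.3000, grad_fn=...)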
Why auto-detect log space?
For ranges spanning orders of magnitude (e.g., 1e-4 to 1e-1), a linear parameterization
explores the lower end poorly. Log space provides uniform coverage:
alpha = LearnableAlpha(0.01, min_value=1e-4, max_value=1.0)
# Automatically uses log space (ratio > 100)
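For intuition, a sketch of why log-space parameterization covers such a range evenly: a parameter moving uniformly in [0, 1] maps to geometrically spaced values (illustrative, assumes positive bounds; not the library's internals):

import math
import torch

def to_value(theta, min_value=1e-4, max_value=1.0):
    # theta in [0, 1] -> value in [min_value, max_value], uniform in log space
    log_min, log_max = math.log(min_value), math.log(max_value)
    return torch.exp(log_min + theta * (log_max - log_min))

print(to_value(torch.tensor(0.5)))  # ~0.01, the geometric midpoint of [1e-4, 1.0]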
Examples
See examples/ directory:
- basic_usage.py - Core concepts
- advanced_usage.py - Advanced concepts
- lstm_residual.py - Recurrent networks
Contributing
Contributions welcome! Please:
- Fork the repository
- Create a feature branch
- Add tests for new functionality
- Ensure pytest and mypy pass
- Submit a pull request
Development setup:
git clone https://github.com/v-garzon/torchresidual.git
cd torchresidual
pip install -e ".[dev]"
pytest tests/
mypy torchresidual/
Citation
If you use torchresidual in your research, please cite:
@software{torchresidual2026,
  author = {Garzón, Víctor},
  title = {torchresidual: Flexible residual connections for PyTorch},
  year = {2026},
  url = {https://github.com/v-garzon/torchresidual}
}
License
MIT License - see LICENSE for details.
Changelog
See CHANGELOG.md for version history.