
torchresidual


Flexible residual connections for PyTorch with a clean, composable API.

Build complex residual architectures without boilerplate. torchresidual provides Record and Apply modules that let you create skip connections of any depth, with automatic shape handling and learnable mixing coefficients.


📖 Quick Start | 📚 Full Documentation | 💡 Examples | ❓ FAQ


Why torchresidual?

Standard PyTorch:

import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        residual = x
        x = self.linear(x)
        x = F.relu(x)
        x = self.norm(x)
        return x + residual  # Manual residual

With torchresidual:

block = ResidualSequential(
    Record(name="input"),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.LayerNorm(64),
    Apply(record_name="input"),  # Automatic residual
)

Benefits:

  • No custom forward() methods
  • Multiple skip connections with named records
  • Automatic projection when dimensions change
  • Five residual operations (add, concat, multiply, gated, highway)
  • Learnable mixing coefficients
  • Works with LSTMs, attention, and any nn.Module

Installation

pip install torchresidual

Requirements: Python ≥3.8, PyTorch ≥1.9

New to torchresidual? See the Quick Start Guide for a 5-minute tutorial.


Quick Start

Basic residual connection

import torch
import torch.nn as nn
from torchresidual import ResidualSequential, Record, Apply

block = ResidualSequential(
    Record(name="input"),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.LayerNorm(64),
    Apply(record_name="input", operation="add"),
)

x = torch.randn(8, 64)
out = block(x)  # Shape: [8, 64]

Multiple skip connections

block = ResidualSequential(
    Record(name="input", need_projection=True),
    nn.Linear(64, 32),
    nn.ReLU(),
    Record(name="mid"),
    nn.Linear(32, 64),
    Apply(record_name="input"),      # Long skip with projection
    nn.LayerNorm(64),
    nn.Linear(64, 32),
    Apply(record_name="mid"),         # Short skip
)

Learnable mixing coefficient

from torchresidual import LearnableAlpha

block = ResidualSequential(
    Record(name="r"),
    nn.Linear(64, 64),
    Apply(
        record_name="r", 
        operation="gated",
        alpha=LearnableAlpha(0.3, min_value=0.0, max_value=1.0)
    ),
)

# Alpha is learned during training
optimizer = torch.optim.Adam(block.parameters(), lr=1e-3)

Automatic projection for shape changes

# Input: [batch, 64] → Output: [batch, 128]
block = ResidualSequential(
    Record(name="r", need_projection=True),  # Enables auto-projection
    nn.Linear(64, 128),
    nn.ReLU(),
    Apply(record_name="r"),  # Automatically projects 64→128
)
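Conceptually, the auto-projection amounts to routing the recorded tensor through a learned linear map before the addition. A hand-written sketch of that equivalence (the names here are illustrative, not the library's internals):

```python
import torch
import torch.nn as nn

# Hand-written equivalent of what need_projection enables (sketch):
proj = nn.Linear(64, 128)  # projects the recorded tensor 64 -> 128
main = nn.Sequential(nn.Linear(64, 128), nn.ReLU())

x = torch.randn(8, 64)
out = main(x) + proj(x)    # residual added after projection
print(out.shape)           # torch.Size([8, 128])
```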

LSTM with residual

from torchresidual import RecurrentWrapper

block = ResidualSequential(
    Record(name="r"),
    RecurrentWrapper(
        nn.LSTM(32, 32, num_layers=2, batch_first=True),
        return_hidden=False
    ),
    Apply(record_name="r"),
)

x = torch.randn(4, 10, 32)  # [batch, seq_len, features]
out = block(x)

API Reference

Core Components

ResidualSequential(*modules)

Drop-in replacement for nn.Sequential with residual connection support.

Example:

block = ResidualSequential(
    nn.Linear(64, 64),
    Record(),
    nn.ReLU(),
    Apply(),
)

Record(need_projection=False, name=None)

Saves the current tensor for later use in a residual connection.

Args:

  • need_projection (bool): If True, Apply will create a linear projection when shapes don't match
  • name (str, optional): Label for this record point. Auto-assigned if None.

Example:

Record(name="input", need_projection=True)

Apply(operation="add", record_name=None, alpha=1.0)

Applies a residual connection using a previously recorded tensor.

Args:

  • operation (str): One of "add", "concat", "multiply", "gated", "highway"
  • record_name (str, optional): Which Record to use. If None, uses most recent.
  • alpha (float or LearnableAlpha): Scaling factor for residual branch

Operations:

  • add: x + α·r (standard ResNet-style)
  • concat: cat([x, r], dim=-1) (DenseNet-style)
  • multiply: x·(1 + α·r) (multiplicative skip)
  • gated: (1-α)·x + α·r (learnable interpolation)
  • highway: T·x + C·r (Highway Networks)

Example:

Apply(operation="gated", record_name="input", alpha=0.5)
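Each operation can be checked against its formula with plain tensors (a hand-rolled sketch of the formulas, not the library's internals; the highway gates T and C are learned, so that row is omitted here):

```python
import torch

x = torch.randn(8, 64)   # current activations
r = torch.randn(8, 64)   # recorded residual
alpha = 0.3

add      = x + alpha * r                 # "add"
concat   = torch.cat([x, r], dim=-1)     # "concat" -> [8, 128]
multiply = x * (1 + alpha * r)           # "multiply"
gated    = (1 - alpha) * x + alpha * r   # "gated"

print(concat.shape)  # torch.Size([8, 128])
```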

LearnableAlpha(initial_value, min_value=0.0, max_value=1.0, use_log_space=None)

Learnable scalar parameter constrained to [min_value, max_value].

Args:

  • initial_value (float): Starting value
  • min_value (float): Lower bound (inclusive)
  • max_value (float): Upper bound (inclusive)
  • use_log_space (bool, optional): Force log or linear parameterization. Auto-detected if None.

Example:

alpha = LearnableAlpha(0.5, min_value=0.0, max_value=1.0)
x = x + alpha() * residual  # alpha() returns constrained value

RecurrentWrapper(module, return_hidden=False)

Wraps LSTM/GRU modules for seamless integration with ResidualSequential.

Args:

  • module (nn.Module): The recurrent module (e.g., nn.LSTM)
  • return_hidden (bool): If True, returns (output, hidden) tuple

Example:

RecurrentWrapper(nn.LSTM(64, 64, batch_first=True), return_hidden=False)
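For intuition, the wrapper's core job, discarding the hidden state so a single tensor flows through the container, can be sketched by hand (an illustrative stand-in, not torchresidual's actual implementation):

```python
import torch
import torch.nn as nn

class SimpleRecurrentWrapper(nn.Module):
    """Drop the hidden state so the module returns one tensor,
    as Sequential-style containers expect (illustrative only)."""
    def __init__(self, rnn):
        super().__init__()
        self.rnn = rnn

    def forward(self, x):
        output, _hidden = self.rnn(x)  # nn.LSTM returns (output, (h, c))
        return output

wrapped = SimpleRecurrentWrapper(nn.LSTM(32, 32, batch_first=True))
x = torch.randn(4, 10, 32)
out = wrapped(x)
print(out.shape)  # torch.Size([4, 10, 32])
```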

Advanced Examples

Transformer-style block

# Multi-head attention with residual and layer norm
block = ResidualSequential(
    Record(name="input"),
    nn.MultiheadAttention(embed_dim=256, num_heads=8),
    Apply(record_name="input"),
    nn.LayerNorm(256),
  
    Record(name="attn_out"),
    nn.Linear(256, 1024),
    nn.ReLU(),
    nn.Linear(1024, 256),
    Apply(record_name="attn_out"),
    nn.LayerNorm(256),
)

Nested residual blocks

inner_block = ResidualSequential(
    Record(),
    nn.Linear(64, 64),
    nn.ReLU(),
    Apply(),
)

outer_block = ResidualSequential(
    Record(),
    inner_block,
    nn.Linear(64, 64),
    Apply(),
)

Complex encoder block

from collections import OrderedDict

encoder = ResidualSequential(OrderedDict([
    ('record_input', Record(need_projection=True, name="input")),
    ('conv1', nn.Conv1d(64, 128, kernel_size=3, padding=1)),
    ('relu1', nn.ReLU()),
    ('record_mid', Record(name="mid")),
    ('conv2', nn.Conv1d(128, 128, kernel_size=3, padding=1)),
    ('relu2', nn.ReLU()),
    ('apply_long', Apply(record_name="input")),
    ('norm', nn.BatchNorm1d(128)),
    ('conv3', nn.Conv1d(128, 64, kernel_size=1)),
    ('apply_short', Apply(record_name="mid", operation="concat")),
]))

Compatibility

Supported Environments

  • Single GPU training: full support
  • CPU training: full support
  • nn.DataParallel: full support, thread-safe via threading.local()
  • DistributedDataParallel: full support, process-safe; recommended for multi-GPU
  • Multi-threaded inference: full support, safe for Flask/FastAPI servers
  • Jupyter notebooks: full support
  • torch.jit.script: planned for v1.1
  • ONNX export: planned for v1.1

Thread Safety

torchresidual uses threading.local() for context management, making it safe for:

  • nn.DataParallel (multiple GPU threads)
  • Multi-threaded inference servers
  • Concurrent requests in production

See docs/DESIGN.md for implementation details.
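The pattern can be sketched in a few lines (a hypothetical illustration of the idea; record/fetch are made-up names, not torchresidual's API):

```python
import threading
import torch

# Each thread gets its own record store, so concurrent forward
# passes (e.g. under nn.DataParallel) cannot clobber each other.
_ctx = threading.local()

def record(name, tensor):
    if not hasattr(_ctx, "records"):
        _ctx.records = {}
    _ctx.records[name] = tensor

def fetch(name):
    return _ctx.records[name]

results = []

def worker(tag):
    t = torch.full((2,), float(tag))
    record("input", t)  # visible only to this thread
    results.append(torch.equal(fetch("input"), t))

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(all(results))  # True: no cross-thread clobbering
```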


Design Philosophy

Why thread-local storage?

Traditional approaches store a parent reference in Apply, creating circular references:

ResidualSequential → Apply → ResidualSequential  # Breaks pickle/deepcopy

torchresidual uses threading.local() to avoid this:

  • ✅ No circular references
  • ✅ Works with pickle, torch.save, deepcopy
  • ✅ Thread-safe for nn.DataParallel
  • ✅ Clean module hierarchy

Why tanh parameterization?

LearnableAlpha uses tanh (not sigmoid) for bounded parameters:

  • Better gradient flow near boundaries
  • Symmetric around midpoint
  • Stable training dynamics
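A minimal sketch of a tanh-bounded learnable scalar, assuming the parameterization described above (illustrative, not the library's actual implementation):

```python
import torch
import torch.nn as nn

class TanhBoundedScalar(nn.Module):
    """Illustrative tanh-bounded learnable scalar (sketch)."""
    def __init__(self, initial, min_value=0.0, max_value=1.0):
        super().__init__()
        self.min_value, self.max_value = min_value, max_value
        # Invert the mapping so forward() starts at `initial`.
        span = max_value - min_value
        y = 2 * (initial - min_value) / span - 1  # in (-1, 1)
        self.raw = nn.Parameter(torch.atanh(torch.tensor(y)))

    def forward(self):
        # tanh keeps gradients alive near the bounds and is
        # symmetric around the midpoint of the range.
        t = torch.tanh(self.raw)  # in (-1, 1)
        return self.min_value + (t + 1) / 2 * (self.max_value - self.min_value)

alpha = TanhBoundedScalar(0.3)
print(float(alpha()))  # ~0.3, and stays in [0, 1] during training
```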

Why auto-detect log space?

For ranges spanning several orders of magnitude (e.g., 1e-4 to 1e-1), a linear parameterization barely explores the lower end of the range. Log space provides uniform coverage:

alpha = LearnableAlpha(0.01, min_value=1e-4, max_value=1.0)
# Automatically uses log space (ratio > 100)
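The detection rule might look something like this (a guess at the heuristic based on the ratio comment above, not the library's actual code):

```python
def should_use_log_space(min_value, max_value, ratio_threshold=100.0):
    """Hypothetical auto-detection: use log space when the range
    spans roughly two or more orders of magnitude."""
    if min_value <= 0:
        return False  # log space needs a strictly positive lower bound
    return max_value / min_value > ratio_threshold

print(should_use_log_space(1e-4, 1.0))  # True  (ratio 10_000)
print(should_use_log_space(0.0, 1.0))   # False (linear range)
```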

Examples

See the examples/ directory in the repository.


Contributing

Contributions welcome! Please:

  1. Fork the repository
  2. Create a feature branch
  3. Add tests for new functionality
  4. Ensure pytest and mypy pass
  5. Submit a pull request

Development setup:

git clone https://github.com/v-garzon/torchresidual.git
cd torchresidual
pip install -e ".[dev]"
pytest tests/
mypy torchresidual/

Citation

If you use torchresidual in your research, please cite:

@software{torchresidual2026,
  author = {Garzón, Víctor},
  title = {torchresidual: Flexible residual connections for PyTorch},
  year = {2026},
  url = {https://github.com/v-garzon/torchresidual}
}

License

MIT License - see LICENSE for details.


Changelog

See CHANGELOG.md for version history.
