PyTorch convolutional layers with global context conditioning
ContextualConv
ContextualConv is a family of custom PyTorch convolutional layers (ContextualConv1d, ContextualConv2d) that support global context conditioning.
These layers behave like the standard PyTorch nn.Conv1d and nn.Conv2d, but also accept a global context vector c. The layer maps c to a per-channel bias and adds it to the output, which conditions the output on contextual information such as class embeddings or latent vectors.
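The mechanism described above can be sketched in a few lines of plain PyTorch. This is a minimal sketch under stated assumptions, not the library's actual code. The class name SketchContextualConv2d is hypothetical. The sketch assumes a single linear projection of c (no h_dim hidden layer) and assumes the wrapped nn.Conv2d keeps its default bias.

```python
from typing import Optional

import torch
import torch.nn as nn


class SketchContextualConv2d(nn.Module):
    """Hypothetical re-implementation: conv output + per-channel bias computed from c."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 padding: int = 0, context_dim: Optional[int] = None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        # Assumption: one linear layer maps c to one bias value per output channel.
        self.context_proj = (nn.Linear(context_dim, out_channels)
                             if context_dim is not None else None)

    def forward(self, x: torch.Tensor, c: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.conv(x)  # (B, out_channels, H_out, W_out)
        if c is None:
            return out  # no context: identical to a plain nn.Conv2d
        if self.context_proj is None:
            raise ValueError("this layer was built without context_dim")
        bias = self.context_proj(c)  # (B, out_channels)
        # Broadcast the per-sample, per-channel bias over every spatial position.
        return out + bias[:, :, None, None]


layer = SketchContextualConv2d(16, 32, kernel_size=3, padding=1, context_dim=10)
x = torch.randn(8, 16, 32, 32)
c = torch.randn(8, 10)
out = layer(x, c)
print(out.shape)  # torch.Size([8, 32, 32, 32])
```

In this sketch the context only adds an offset per channel, so the difference between outputs with and without c is constant across spatial positions. The convolution kernel itself is not modulated.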
🔧 Features
- ⚙️ Drop-in replacement for nn.Conv1d and nn.Conv2d
- 🧠 Context-aware: injects global vector as output bias
- 🧱 Based on standard PyTorch convolution
- 🧠 Optional hidden layer (h_dim) for MLP processing of c
- 📦 Fully differentiable and unit-tested
📦 Installation
Clone the repo or copy contextual_conv.py into your project, then:
pip install -r requirements.txt
To install PyTorch, follow the official guide: https://pytorch.org/get-started/locally/
Example (CPU only):
pip install torch --index-url https://download.pytorch.org/whl/cpu
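After installing, a quick sanity check confirms which PyTorch build is active. This is a generic PyTorch check, not something the project itself provides:

```python
import torch

# Report the installed PyTorch version.
print("PyTorch version:", torch.__version__)
# With the CPU-only wheel this is expected to print False.
print("CUDA available:", torch.cuda.is_available())
```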
🚀 Usage
2D Example (with context and MLP)
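import torch
from contextual_conv import ContextualConv2d

# h_dim=64: the context passes through an MLP with one hidden layer of width 64
conv = ContextualConv2d(
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    context_dim=10,
    h_dim=64,
)

x = torch.randn(8, 16, 32, 32)  # (batch, channels, height, width)
c = torch.randn(8, 10)          # (batch, context_dim)
out = conv(x, c)  # shape: (8, 32, 32, 32)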
1D Example (linear context projection)
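import torch
from contextual_conv import ContextualConv1d

# No h_dim: the context is mapped to the bias by a single linear projection
conv = ContextualConv1d(
    in_channels=16,
    out_channels=32,
    kernel_size=5,
    padding=2,
    context_dim=6,
)

x = torch.randn(4, 16, 100)  # (batch, channels, length)
c = torch.randn(4, 6)        # (batch, context_dim)
out = conv(x, c)  # shape: (4, 32, 100)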
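The main shape-dependent difference between the 1D and 2D cases is how the (B, out_channels) context bias is broadcast over the spatial axes. The helper below, add_channel_bias, is hypothetical (not part of contextual_conv), as is the stand-alone nn.Linear projection; it is a sketch of the broadcasting step that works for any number of spatial dimensions:

```python
import torch
import torch.nn as nn


def add_channel_bias(out: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Add a (B, C) bias to a (B, C, *spatial) conv output for any number of spatial dims."""
    if bias.shape != out.shape[:2]:
        raise ValueError(
            f"bias shape {tuple(bias.shape)} does not match (B, C) = {tuple(out.shape[:2])}"
        )
    # Append one singleton axis per spatial dimension: (B, C) -> (B, C, 1, ..., 1).
    return out + bias.reshape(*bias.shape, *([1] * (out.dim() - 2)))


conv1d = nn.Conv1d(16, 32, kernel_size=5, padding=2)
proj = nn.Linear(6, 32)  # hypothetical context projection (context_dim=6 -> out_channels=32)
x = torch.randn(4, 16, 100)
c = torch.randn(4, 6)
out = add_channel_bias(conv1d(x), proj(c))
print(out.shape)  # torch.Size([4, 32, 100])
```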
Without context
conv = ContextualConv2d(16, 32, kernel_size=3, padding=1)  # no context_dim
x = torch.randn(8, 16, 32, 32)
out = conv(x)  # behaves like a standard nn.Conv2d; shape: (8, 32, 32, 32)
📐 Context Vector
- Shape: (B, context_dim)
- Passed through a ContextProcessor (a Linear layer, or an MLP when h_dim is set)
- Output shape: (B, out_channels) → added to the output as a per-channel bias, broadcast over all spatial positions
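A context processor with an optional hidden layer might look like the sketch below. SketchContextProcessor is a hypothetical name, and the ReLU activation is an assumption; the real ContextProcessor may be structured differently.

```python
from typing import Optional

import torch
import torch.nn as nn


class SketchContextProcessor(nn.Module):
    """Hypothetical stand-in for ContextProcessor: maps (B, context_dim) -> (B, out_channels)."""

    def __init__(self, context_dim: int, out_channels: int, h_dim: Optional[int] = None):
        super().__init__()
        if h_dim is None:
            # No hidden layer: a single linear projection.
            self.net = nn.Linear(context_dim, out_channels)
        else:
            # One hidden layer of width h_dim; ReLU is an assumed activation.
            self.net = nn.Sequential(
                nn.Linear(context_dim, h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, out_channels),
            )

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        return self.net(c)


linear_proc = SketchContextProcessor(context_dim=6, out_channels=32)
mlp_proc = SketchContextProcessor(context_dim=10, out_channels=32, h_dim=64)
print(linear_proc(torch.randn(4, 6)).shape)  # torch.Size([4, 32])
print(mlp_proc(torch.randn(8, 10)).shape)    # torch.Size([8, 32])
```

The linear variant makes the bias an affine function of c; the hidden layer lets it be a nonlinear function at the cost of extra parameters.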
🧪 Tests
All tests live in tests/test_contextual_conv.py.
Run them with:
pytest tests/
📘 Documentation
Full documentation is available at:
👉 https://contextualconv.readthedocs.io
It includes an API reference, an explanation of the architecture, and usage tips.
📄 License
Licensed under GNU GPLv3.
🤝 Contributing
You're welcome to:
- Add ContextualConv3d
- Suggest other context conditioning strategies
- Add notebook examples
- Improve performance
Open an issue or PR to contribute!
📫 Contact
Questions? Issues? Reach out on GitHub or open a discussion.
Download files
File details: contextual_conv-0.2.0.tar.gz (source distribution)
- Size: 16.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.13

| Algorithm | Hash digest |
|---|---|
| SHA256 | e52812d619d446129b1af5a928907f2dd8102f8c46c3d37984a2d9eb822ae909 |
| MD5 | 9fee4d988b2a828669b480d4eff17428 |
| BLAKE2b-256 | fb8dd1598f8166abc09144cb922d1d2fa43cf96f1aed77b4f3922e5ad1a09bb8 |
File details: contextual_conv-0.2.0-py3-none-any.whl (built distribution)
- Size: 16.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.13

| Algorithm | Hash digest |
|---|---|
| SHA256 | 6002ed4599623c36ca019e48489f147eed03531e4141dd8484da264e63a2cc82 |
| MD5 | 4186ebe59587fee5dffd3ee126fd3358 |
| BLAKE2b-256 | 2d30de7d383298726473c5a88a12110b35e6c3c13aecb0cf43d54cb49b787382 |
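To check a downloaded file against the SHA256 digests listed above, Python's standard hashlib module is sufficient. This is a generic sketch: sha256_of, verify, and EXPECTED are hypothetical names, and the expected digests are copied from the hash tables for version 0.2.0.

```python
import hashlib
from pathlib import Path

# SHA256 digests for the 0.2.0 release files, copied from the tables above.
EXPECTED = {
    "contextual_conv-0.2.0.tar.gz":
        "e52812d619d446129b1af5a928907f2dd8102f8c46c3d37984a2d9eb822ae909",
    "contextual_conv-0.2.0-py3-none-any.whl":
        "6002ed4599623c36ca019e48489f147eed03531e4141dd8484da264e63a2cc82",
}


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str) -> bool:
    """True if the file's SHA256 matches the published digest for its file name."""
    return sha256_of(path) == EXPECTED[Path(path).name]
```

For example, verify("contextual_conv-0.2.0.tar.gz") should return True for an untampered download placed in the current directory.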