ContextualConv
PyTorch convolutional layers with global context conditioning: per-channel bias, scale, or FiLM-style modulation.
🚀 Quick start
```python
import torch
from contextual_conv import ContextualConv2d

# FiLM-style (scale + bias)
conv = ContextualConv2d(
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    context_dim=10,     # size of global vector c
    h_dim=64,           # optional MLP hidden dim
    use_scale=True,     # γ(c)
    use_bias=True,      # β(c)
    scale_mode="film",  # or "scale"
)

x = torch.randn(8, 16, 32, 32)  # feature map
c = torch.randn(8, 10)          # context vector
out = conv(x, c)                # shape: (8, 32, 32, 32)
```
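The output shape is the same as a plain convolution's, since the context only modulates channels, not spatial dimensions. A quick sanity check with `nn.Conv2d` (not the ContextualConv layer itself) using the same geometry:

```python
import torch
import torch.nn as nn

# Same geometry as the quick-start example above: 3x3 kernel with
# padding=1 preserves the 32x32 spatial size.
plain = nn.Conv2d(16, 32, kernel_size=3, padding=1)
x = torch.randn(8, 16, 32, 32)
y = plain(x)
print(y.shape)  # torch.Size([8, 32, 32, 32])
```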
Modes at a glance

| use_scale | use_bias | scale_mode | Behaviour |
|---|---|---|---|
| False | True | – | Contextual bias only |
| True | False | "scale" | Scale only: out * γ |
| True | True | "film" | FiLM: out * (1 + γ) + β |
| True | True | "scale" | Scale + shift: out * γ + β |
| False | False | – | Plain convolution (no modulation) |
If context_dim is provided, at least one of use_scale or use_bias must be True.
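The formulas in the table are plain elementwise operations. A minimal sketch with raw tensors (no ContextualConv layer involved), where γ and β are per-channel vectors broadcast over the spatial dimensions:

```python
import torch

out = torch.randn(8, 32, 16, 16)              # raw convolution output
gamma = torch.randn(8, 32).view(8, 32, 1, 1)  # γ(c), broadcast over H, W
beta = torch.randn(8, 32).view(8, 32, 1, 1)   # β(c)

film = out * (1 + gamma) + beta  # scale_mode="film"
scaled = out * gamma + beta      # scale_mode="scale" with use_bias=True
bias_only = out + beta           # use_scale=False, use_bias=True
```

Note that subtracting the two scaled variants recovers the raw output: `film - scaled == out * (1 + γ) - out * γ == out`.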
🔧 Key features
- ⚙️ Drop-in replacement for nn.Conv1d / nn.Conv2d – same arguments plus optional context options
- 🧠 Global vector conditioning via learnable γ(c) and/or β(c)
- 🔀 Modulation modes: scale_mode="film" → out * (1 + γ); scale_mode="scale" → out * γ
- 🪶 Lightweight – one small MLP (or single Linear) per layer
- 🧑‍🔬 FiLM ready – reproduce Feature-wise Linear Modulation in two lines
- 🧩 Modular – combine with any architecture, works on CPU / GPU
- 📤 Infer context vectors from unmodulated outputs with .infer_context()
- ✅ Unit-tested and documented
📦 Installation
pip install contextual-conv # version 0.6.3 on PyPI
Or install from source:
git clone https://github.com/abbassix/ContextualConv.git
cd ContextualConv
pip install -e .[dev]
📐 Context vector details
- Shape: (B, context_dim) – one global descriptor per sample (class-label embedding, latent code, etc.)
- Processed by a ContextProcessor:
  - Linear(context_dim, out_dim) for bias-only / scale-only
  - a small MLP if h_dim is set
- Output dims:
  - out_channels → bias or scale
  - 2 × out_channels → FiLM (scale + bias)
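The shape bookkeeping above can be sketched with a stand-in processor. This is a hypothetical illustration of the FiLM case with h_dim set, not the library's actual ContextProcessor code:

```python
import torch
import torch.nn as nn

context_dim, h_dim, out_channels = 10, 64, 32

# With h_dim set and FiLM enabled, the head maps the context vector
# to 2 * out_channels values: γ and β concatenated along dim 1.
processor = nn.Sequential(
    nn.Linear(context_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, 2 * out_channels),
)

c = torch.randn(8, context_dim)
gamma, beta = processor(c).chunk(2, dim=1)
print(gamma.shape, beta.shape)  # torch.Size([8, 32]) torch.Size([8, 32])
```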
🔎 Context inference
You can extract the context vector inferred from the output using:
context = conv.infer_context(x)
To also get the unmodulated output from the convolution layer:
context, raw_out = conv.infer_context(x, return_raw_output=True)
This is useful when you need both the input’s context and its original unmodulated features.
🧪 Running tests
Run the full test suite with coverage:
pytest --cov=contextual_conv --cov-report=term-missing
📘 Documentation
Full API reference & tutorials: https://contextualconv.readthedocs.io
🤝 Contributing
Bug reports, feature requests, and PRs are welcome! See CONTRIBUTING.md.
📄 License
GNU GPLv3 – see LICENSE file for details.