PyTorch convolutional layers with global context conditioning

ContextualConv

ContextualConv is a family of custom PyTorch convolutional layers (ContextualConv1d, ContextualConv2d) that support global context conditioning.

These layers behave like the standard PyTorch nn.Conv1d and nn.Conv2d, but optionally accept a global context vector c. The layer projects c to a per-channel bias and adds it to the output, which conditions the convolution on contextual information such as class embeddings or latent vectors.
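A standard nn.Conv2d already adds a learned bias per output channel, but that bias is shared by every sample in the batch. The sketch below uses plain PyTorch (torch.nn.functional.conv2d), not the package's code. It first checks that identity, then adds a bias of shape (B, out_channels) instead, which is the shape a context-derived bias has. The random `context_bias` tensor is a hypothetical stand-in for the projected context vector.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, C_in, C_out, H, W = 2, 3, 4, 5, 5
x = torch.randn(B, C_in, H, W)
weight = torch.randn(C_out, C_in, 3, 3)
shared_bias = torch.randn(C_out)

# Standard conv bias: one value per output channel, shared across the batch.
with_bias = F.conv2d(x, weight, shared_bias, padding=1)
no_bias = F.conv2d(x, weight, None, padding=1)
manual = no_bias + shared_bias.view(1, C_out, 1, 1)

# Context-style bias: one value per (sample, output channel).
# Hypothetical stand-in for the projected context vector c.
context_bias = torch.randn(B, C_out)
conditioned = no_bias + context_bias.view(B, C_out, 1, 1)

print(torch.allclose(with_bias, manual, atol=1e-5))  # True
print(conditioned.shape)  # torch.Size([2, 4, 5, 5])
```

In other words, the ordinary conv bias is the special case where every sample receives the same offset.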


🔧 Features

  • ⚙️ Drop-in replacement for nn.Conv1d and nn.Conv2d
  • 🧠 Context-aware: injects a global context vector as a per-channel output bias
  • 🧱 Based on standard PyTorch convolution
  • 🧠 Optional hidden layer (h_dim) for MLP processing of c
  • 📦 Fully differentiable and unit-tested

📦 Installation

Clone the repo or copy contextual_conv.py into your project, then install the dependencies:

pip install -r requirements.txt

To install PyTorch, follow the official guide: https://pytorch.org/get-started/locally/

Example (CPU only):

pip install torch --index-url https://download.pytorch.org/whl/cpu

🚀 Usage

2D Example (with context and MLP)

from contextual_conv import ContextualConv2d
import torch

conv = ContextualConv2d(
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    context_dim=10,
    h_dim=64
)

x = torch.randn(8, 16, 32, 32)
c = torch.randn(8, 10)

out = conv(x, c)  # shape: (8, 32, 32, 32)

1D Example (linear context projection)

from contextual_conv import ContextualConv1d
import torch

conv = ContextualConv1d(
    in_channels=16,
    out_channels=32,
    kernel_size=5,
    padding=2,
    context_dim=6
)

x = torch.randn(4, 16, 100)
c = torch.randn(4, 6)

out = conv(x, c)  # shape: (4, 32, 100)
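The spatial sizes in the shape comments follow from PyTorch's documented output-size formula for nn.Conv1d and nn.Conv2d, applied to each spatial dimension. Because these layers wrap standard convolution, the same formula should apply here. The helper name `conv_out_len` is introduced only for illustration.

```python
def conv_out_len(length: int, kernel_size: int, padding: int = 0,
                 stride: int = 1, dilation: int = 1) -> int:
    """Output length of one spatial dimension, per the nn.Conv1d/nn.Conv2d docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# 2D example: 32x32 input, kernel_size=3, padding=1 -> 32x32
print(conv_out_len(32, kernel_size=3, padding=1))   # 32
# 1D example: length 100, kernel_size=5, padding=2 -> 100
print(conv_out_len(100, kernel_size=5, padding=2))  # 100
```

With stride 1 and padding = (kernel_size - 1) // 2, an odd kernel size preserves the input length, which is why both examples keep their spatial size.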

Without context

conv = ContextualConv2d(16, 32, kernel_size=3, padding=1)  # no context_dim
x = torch.randn(8, 16, 32, 32)
out = conv(x)  # behaves like a standard nn.Conv2d; shape: (8, 32, 32, 32)

📐 Context Vector

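  • Shape: (B, context_dim)
  • Passed through a ContextProcessor: a linear projection by default, or an MLP with a hidden layer of size h_dim when h_dim is set
  • Output shape: (B, out_channels), broadcast over the spatial dimensions and added to the convolution output as a per-sample, per-channel bias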
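This mechanism can be sketched in a few lines of plain PyTorch. The sketch is a hypothetical re-creation for illustration, not the package's implementation. The class names `SketchContextProcessor` and `SketchContextualConv2d` are invented here. The Linear, ReLU, Linear layout of the MLP branch (in particular the ReLU) is also an assumption, because the README only says that h_dim adds a hidden layer.

```python
from typing import Optional

import torch
import torch.nn as nn


class SketchContextProcessor(nn.Module):
    """Maps c of shape (B, context_dim) to (B, out_dim).

    Hypothetical: a single Linear when h_dim is None, otherwise
    Linear -> ReLU -> Linear (the ReLU is an assumed choice).
    """

    def __init__(self, context_dim: int, out_dim: int, h_dim: Optional[int] = None):
        super().__init__()
        if h_dim is None:
            self.net = nn.Linear(context_dim, out_dim)
        else:
            self.net = nn.Sequential(
                nn.Linear(context_dim, h_dim),
                nn.ReLU(),
                nn.Linear(h_dim, out_dim),
            )

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        return self.net(c)


class SketchContextualConv2d(nn.Module):
    """Hypothetical: a plain nn.Conv2d plus an optional context-derived bias."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 padding: int = 0, context_dim: Optional[int] = None,
                 h_dim: Optional[int] = None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.processor = (
            SketchContextProcessor(context_dim, out_channels, h_dim)
            if context_dim is not None else None
        )

    def forward(self, x: torch.Tensor, c: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.conv(x)
        if self.processor is not None and c is not None:
            bias = self.processor(c)              # (B, out_channels)
            out = out + bias[:, :, None, None]    # broadcast over H and W
        return out


conv = SketchContextualConv2d(16, 32, 3, padding=1, context_dim=10, h_dim=64)
x = torch.randn(8, 16, 32, 32)
c = torch.randn(8, 10)
out = conv(x, c)
print(out.shape)  # torch.Size([8, 32, 32, 32])
```

In this sketch the context term is added after an ordinary nn.Conv2d, so calling it without c returns the plain convolution output, which matches the "Without context" usage. A 1D variant would reshape the bias with bias[:, :, None] instead.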

🧪 Tests

All tests live in tests/test_contextual_conv.py.

Run them with:

pytest tests/

📘 Documentation

Full documentation is available at:

👉 https://contextualconv.readthedocs.io

It includes an API reference, an explanation of the architecture, and usage tips.


📄 License

Licensed under GNU GPLv3.


🤝 Contributing

You're welcome to:

  • Add ContextualConv3d
  • Suggest other context conditioning strategies
  • Add notebook examples
  • Improve performance

Open an issue or PR to contribute!


📫 Contact

Questions? Issues? Reach out on GitHub or open a discussion.

Download files

Source Distribution

contextual_conv-0.2.0.tar.gz (16.5 kB)

Built Distribution

contextual_conv-0.2.0-py3-none-any.whl (16.4 kB, Python 3)

File details

Details for the file contextual_conv-0.2.0.tar.gz.

File metadata

  • Download URL: contextual_conv-0.2.0.tar.gz
  • Size: 16.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.13

File hashes

Hashes for contextual_conv-0.2.0.tar.gz
Algorithm Hash digest
SHA256 e52812d619d446129b1af5a928907f2dd8102f8c46c3d37984a2d9eb822ae909
MD5 9fee4d988b2a828669b480d4eff17428
BLAKE2b-256 fb8dd1598f8166abc09144cb922d1d2fa43cf96f1aed77b4f3922e5ad1a09bb8

File details

Details for the file contextual_conv-0.2.0-py3-none-any.whl.

File hashes

Hashes for contextual_conv-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6002ed4599623c36ca019e48489f147eed03531e4141dd8484da264e63a2cc82
MD5 4186ebe59587fee5dffd3ee126fd3358
BLAKE2b-256 2d30de7d383298726473c5a88a12110b35e6c3c13aecb0cf43d54cb49b787382
