
PyTorch convolutional layers with global context conditioning


ContextualConv

ContextualConv is a family of custom PyTorch convolutional layers (ContextualConv1d, ContextualConv2d) that support global context conditioning.

These layers mimic standard PyTorch convolutions using im2col + matrix multiplication, while allowing a global context vector c to modulate the output at all spatial or temporal positions.
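
As a rough illustration of the im2col idea (a minimal sketch using only standard PyTorch calls, not code taken from contextual_conv.py), an ordinary 2D convolution can be reproduced with F.unfold followed by a matrix multiplication. The helper name conv2d_via_unfold is hypothetical, and the sketch assumes groups=1 and dilation=1.

```python
import torch
import torch.nn.functional as F


def conv2d_via_unfold(x, weight, bias=None, stride=1, padding=0):
    """Compute a 2D convolution (groups=1, dilation=1) with unfold + matmul."""
    n, _, h, w = x.shape
    out_channels, _, kh, kw = weight.shape

    # (N, C_in * kh * kw, L): one column per sliding-window position
    patches = F.unfold(x, kernel_size=(kh, kw), stride=stride, padding=padding)

    # Flatten each filter into a row: (C_out, C_in * kh * kw)
    w_flat = weight.reshape(out_channels, -1)

    # (C_out, K) @ (N, K, L) broadcasts over the batch -> (N, C_out, L)
    out = w_flat @ patches
    if bias is not None:
        out = out + bias.view(1, -1, 1)

    h_out = (h + 2 * padding - kh) // stride + 1
    w_out = (w + 2 * padding - kw) // stride + 1
    return out.view(n, out_channels, h_out, w_out)


torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
weight = torch.randn(6, 4, 3, 3)
bias = torch.randn(6)

reference = F.conv2d(x, weight, bias, padding=1)
result = conv2d_via_unfold(x, weight, bias, padding=1)
max_abs_diff = (reference - result).abs().max().item()  # small float rounding only
```

The two results should agree up to floating-point rounding, which is the property the explicit formulation relies on.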


🔧 Features

  • ⚙️ Drop-in replacement for nn.Conv1d and nn.Conv2d
  • 🧠 Context-aware: injects global information into every location
  • 🧱 Uses unfold (im2col) to compute convolution explicitly
  • 📦 Fully differentiable, supports grouped convolutions
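
Grouped convolutions also fit the unfold formulation. The sketch below is one possible way to do it, assuming stride=1 and dilation=1; the function name grouped_conv2d_via_unfold is hypothetical and this is not necessarily how the library implements groups. It splits the unfolded patches and the filters by group, then contracts them with einsum.

```python
import torch
import torch.nn.functional as F


def grouped_conv2d_via_unfold(x, weight, groups, padding=0):
    """Grouped 2D convolution (stride=1, dilation=1) via unfold + einsum."""
    n, _, h, w = x.shape
    c_out, c_in_per_group, kh, kw = weight.shape
    k = c_in_per_group * kh * kw

    patches = F.unfold(x, kernel_size=(kh, kw), padding=padding)  # (N, C_in*kh*kw, L)
    # unfold is channel-major, so each consecutive block of k rows is one group
    patches = patches.reshape(n, groups, k, -1)                   # (N, G, k, L)
    w_grouped = weight.reshape(groups, c_out // groups, k)        # (G, C_out/G, k)

    # Contract over k separately inside every group
    out = torch.einsum("gok,ngkl->ngol", w_grouped, patches)
    h_out = h + 2 * padding - kh + 1
    w_out = w + 2 * padding - kw + 1
    return out.reshape(n, c_out, h_out, w_out)


torch.manual_seed(0)
x_grouped = torch.randn(2, 8, 10, 10)
w_example = torch.randn(12, 2, 3, 3)  # groups=4: 8 / 4 = 2 input channels per group

grouped_reference = F.conv2d(x_grouped, w_example, padding=1, groups=4)
grouped_result = grouped_conv2d_via_unfold(x_grouped, w_example, groups=4, padding=1)
```

Gradients flow through unfold, reshape, and einsum, so a layer built this way stays differentiable end to end.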

📦 Installation

Clone the repo or copy contextual_conv.py into your project, then install dependencies:

pip install -r requirements.txt

Then install the correct PyTorch version for your system (CPU or CUDA) by following the official instructions:

🔗 https://pytorch.org/get-started/locally/

Examples:

  • CPU-only:

    pip install torch --index-url https://download.pytorch.org/whl/cpu
    
  • CUDA 11.8:

    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    

🚀 Usage

2D Example

import torch
from contextual_conv import ContextualConv2d

conv2d = ContextualConv2d(
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    c_dim=10  # context dimensionality
)

x = torch.randn(8, 16, 32, 32)
c = torch.randn(8, 10)

out = conv2d(x, c)  # shape: (8, 32, 32, 32)

1D Example

import torch
from contextual_conv import ContextualConv1d

conv1d = ContextualConv1d(
    in_channels=16,
    out_channels=32,
    kernel_size=5,
    padding=2,
    c_dim=6
)

x = torch.randn(4, 16, 100)
c = torch.randn(4, 6)

out = conv1d(x, c)  # shape: (4, 32, 100)

Without context

x = torch.randn(8, 16, 32, 32)  # 2D input, as in the 2D example
conv = ContextualConv2d(16, 32, kernel_size=3, padding=1)  # c_dim omitted
out = conv(x)  # works even without `c`

📐 Context Vector

  • Shape: (N, c_dim) or (N, 1, c_dim)
  • Broadcast to all positions (spatial or temporal)
  • Concatenated to each unfolded input patch
  • Modulated by learnable c_weight before being added to the convolution output
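
The steps above can be sketched roughly as follows. This is an illustration under stated assumptions, not the library's actual code: the function name contextual_conv2d_sketch is hypothetical, c_weight is assumed to have shape (out_channels, c_dim), and stride=1 with groups=1 is assumed. Under those assumptions, concatenating c to every patch is equivalent to adding a per-sample, per-channel offset c @ c_weight.T to a plain convolution output.

```python
import torch
import torch.nn.functional as F


def contextual_conv2d_sketch(x, c, weight, c_weight, padding=0):
    """Hypothetical sketch: append c to every unfolded patch, then matmul."""
    n, _, h, w = x.shape
    out_channels, _, kh, kw = weight.shape

    patches = F.unfold(x, kernel_size=(kh, kw), padding=padding)  # (N, K, L)
    num_positions = patches.shape[-1]

    # Accept (N, c_dim) or (N, 1, c_dim), then repeat c at every position
    c = c.reshape(n, -1)
    c_tiled = c.unsqueeze(-1).expand(-1, -1, num_positions)       # (N, c_dim, L)

    # Concatenate context to each patch, and context weights to each filter
    patches_c = torch.cat([patches, c_tiled], dim=1)              # (N, K + c_dim, L)
    w_full = torch.cat([weight.reshape(out_channels, -1), c_weight], dim=1)

    out = w_full @ patches_c                                      # (N, C_out, L)
    h_out = h + 2 * padding - kh + 1
    w_out = w + 2 * padding - kw + 1
    return out.view(n, out_channels, h_out, w_out)


torch.manual_seed(0)
x_ctx = torch.randn(3, 4, 8, 8)
c_ctx = torch.randn(3, 1, 5)          # the (N, 1, c_dim) form
w_ctx = torch.randn(6, 4, 3, 3)
c_weight = torch.randn(6, 5)          # assumed shape: (out_channels, c_dim)

ctx_out = contextual_conv2d_sketch(x_ctx, c_ctx, w_ctx, c_weight, padding=1)

# Equivalent view: plain convolution plus a context-dependent channel offset
offset = (c_ctx.reshape(3, -1) @ c_weight.T)[:, :, None, None]
ctx_expected = F.conv2d(x_ctx, w_ctx, padding=1) + offset
```

Because the context contribution is the same at every position, it never needs its own spatial modeling; it only shifts each output channel per sample.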

🔍 When to Use

Use ContextualConv layers when:

  • You want to inject external or global information into feature maps
  • You need interpretable and customizable convolution logic
  • You want context-aware dynamic filtering with no extra spatial modeling

🧪 Tests

Unit tests are included in tests/test_contextual_conv.py.

✅ To run the tests:

pip install -r requirements.txt
pip install torch  # see installation instructions above
pytest tests/

The tests compare ContextualConv1d and ContextualConv2d against standard PyTorch layers with context disabled, ensuring correctness.


🤖 GitHub Actions (CI)

A GitHub Actions workflow in .github/workflows/test.yml automatically runs tests on push and pull requests using the CPU version of PyTorch.


📄 License

GNU GPLv3 License


🤝 Contributing

Contributions welcome! You can:

  • Add ContextualConv3d support
  • Improve performance (e.g., using einsum)
  • Write more advanced tests or benchmarks
  • Create useful networks that use these layers

Open a pull request or issue to get started.


📫 Contact

Questions or suggestions? Open an issue or reach out via GitHub.
