Modular brain-inspired cognitive augmentation layers for neural networks

Cognitive Augmentation Layers (cognitive_aug)

Modular, brain-inspired cognitive augmentation layers for neural networks, built atop PyTorch.

This library provides plug-and-play components that add biologically inspired capabilities to existing deep learning architectures, starting with Global Workspace Theory (GWT) mechanisms for attention, integration, and dynamic broadcast routing.


Features (Phase v0.1 MVP)

  • Module Registry: Track and coordinate active neural networks acting as specialized cognitive modules.
  • Data Flow Manager: High-performance, framework-integrated routing of latent representations across dynamic computation cycles.
  • Global Workspace Theory (GWT) Layer:
    • Configurable single-slot (high biological fidelity) and multi-slot (engineering focus) workspace bottleneck.
    • Pluggable AttentionSelector supporting Top-Down Key-Query matching and Bottom-Up salience selection.
    • Dynamic Ignition: Non-linear, threshold-gated attentional selection where only highly active features are broadcast.
    • BroadcastEngine that distributes unified cognitive states back to all modules as context.
  • Non-Intrusive Wrappers: ModuleAdapter wraps standard PyTorch nn.Modules using forward/backward hooks without polluting or altering original model classes.
  • Optimized Cognitive Add-ons:
    • Differentiable Selection: CosineSimilaritySelector, VectorizedCrossAttentionSelector (built on PyTorch's scaled_dot_product_attention, which can dispatch to FlashAttention kernels on supported hardware), and EfficientGumbelSoftmaxSelector (hard winner-take-all routing).
    • Low-Overhead Salience: MagnitudeSalience (L2 norm), EntropySalience (Shannon entropy confidence), and stateful TemporalSurpriseSalience (temporal cosine distance tracking).
    • Decay Working Memory: DecayWorkingMemory stateful wrapper with in-place exponential decay mutations and blended historical traces.
    • Parallel Downstream Routing: CognitiveOutputRouter broadcasting workspace representations back into multiple output heads in a single vectorized matrix projection pass.
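The Dynamic Ignition idea above (only sufficiently active features are broadcast) can be illustrated with a small sketch. The function name `ignite`, the tensor shapes, and the hard argmax selection are assumptions made for illustration; they are not taken from the library's implementation.

```python
import torch

def ignite(candidates: torch.Tensor, salience: torch.Tensor,
           threshold: float = 0.5) -> torch.Tensor:
    """Pick the most salient module per batch item; broadcast zeros if it is too weak.

    candidates: [B, M, D] latent vectors from M modules (assumed layout)
    salience:   [B, M] scores, assumed to lie in [0, 1]
    """
    best_score, best_idx = salience.max(dim=1)                # both [B]
    batch = torch.arange(candidates.size(0), device=candidates.device)
    winners = candidates[batch, best_idx]                     # [B, D]
    # Threshold gate: 1.0 where the winner ignites, 0.0 otherwise.
    gate = (best_score >= threshold).to(candidates.dtype).unsqueeze(-1)
    return winners * gate

candidates = torch.tensor([[[1.0, 2.0], [3.0, 4.0]],
                           [[5.0, 6.0], [7.0, 8.0]]])
salience = torch.tensor([[0.9, 0.2],
                         [0.1, 0.3]])
broadcast = ignite(candidates, salience)
# Item 0 ignites with module 0; item 1 stays below threshold and broadcasts zeros.
```

The hard argmax here passes no gradient to the salience scores, which is one reason a differentiable selector may be preferable during training.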

Installation

To install in editable mode with development dependencies, run the following from the root of a source checkout:

pip install -e ".[dev]"

Quickstart Example

Here is a simple example showing how to register an image classification module and a text processing module, enabling them to communicate via a shared Global Workspace:

import torch
import torch.nn as nn
from cognitive_aug import CognitiveAugEngine, GlobalWorkspace, ModuleAdapter

# 1. Define your standard PyTorch modules
class VisualModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 8 * 8, 256)
    def forward(self, x):
        features = torch.relu(self.conv(x))
        features = features.view(features.size(0), -1)
        return self.fc(features)

class TextModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)
        self.fc = nn.Linear(128, 256)
    def forward(self, x):
        emb = self.emb(x)
        out, _ = self.lstm(emb)
        return self.fc(out[:, -1, :])

# 2. Instantiate modules
vis_mod = VisualModule()
txt_mod = TextModule()

# 3. Create the Cognitive Engine and Register Modules
engine = CognitiveAugEngine()

# Register modules with their latent dimensions (must match workspace target or be mapped)
vis_adapter = engine.register_module(
    name="vision",
    module=vis_mod,
    latent_dim=256
)

txt_adapter = engine.register_module(
    name="text",
    module=txt_mod,
    latent_dim=256
)

# 4. Instantiate Global Workspace
workspace = GlobalWorkspace(
    latent_dim=256,
    attention_type="key-query",
    ignition_threshold=0.5,
    workspace_slots=1  # Biological single-slot GWT bottleneck
)

# Attach workspace to the engine
engine.attach_workspace(workspace)

# 5. Run standard forward propagation
# Adapters will automatically capture latent representations via PyTorch forward hooks!
dummy_img = torch.randn(4, 3, 8, 8)
dummy_txt = torch.randint(0, 1000, (4, 10))

# Forward pass as usual
vis_out = vis_mod(dummy_img)
txt_out = txt_mod(dummy_txt)

# Perform GWT cycle: Selection & Global Broadcast!
workspace_state = engine.step()

print("Global Workspace broadcast vector shape:", workspace_state.shape)
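Step 5 relies on adapters capturing latent representations through PyTorch forward hooks. As a rough illustration of that mechanism only, the sketch below uses a hypothetical `LatentCapture` helper; it is not the library's ModuleAdapter, whose internals are not shown here.

```python
import torch
import torch.nn as nn

class LatentCapture:
    """Hypothetical helper: records the most recent output of a wrapped module."""

    def __init__(self, module: nn.Module):
        self.latent = None
        # register_forward_hook returns a handle that can later detach the hook.
        self._handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Called after every forward pass; the module itself is left unchanged.
        self.latent = output

    def remove(self):
        self._handle.remove()

net = nn.Linear(8, 4)
capture = LatentCapture(net)
out = net(torch.randn(2, 8))          # ordinary forward call triggers the hook
captured_shape = tuple(capture.latent.shape)
capture.remove()                      # later forward calls are no longer recorded
```

Because the hook stores the output tensor itself rather than a detached copy, gradients can still flow through anything computed from the captured latent.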

High-Performance Modular Add-ons

For advanced cognitive systems or performance-critical setups, the package includes highly optimized, zero-copy, fully vectorized components.

1. Advanced Attention Selectors (selectors.py)

Replace the default selector with one of three high-speed subclasses inheriting from BaseSelector:

  • CosineSimilaritySelector: Uses vectorized matrix-based cosine_similarity to rapidly match incoming states to top-down goals.
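  • VectorizedCrossAttentionSelector: Uses PyTorch's native scaled_dot_product_attention with an identity matrix as the Value tensor, so the output is the attention weight distribution itself. The call can dispatch to fused kernels such as FlashAttention where the backend supports the inputs, and it remains fully differentiable.
  • EfficientGumbelSoftmaxSelector: A hard, winner-take-all routing mechanism using a single gumbel_softmax pass that remains differentiable (PyTorch's gumbel_softmax with hard=True uses a straight-through estimator for the backward pass).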
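The two less obvious selection tricks described above can be sketched with plain PyTorch functions. The shapes (`B` batch items, `M` modules, `D` latent dimensions) and variable names are assumptions for illustration, and the sketch does not reproduce the selector classes themselves. It requires PyTorch 2.0 or later for `scaled_dot_product_attention`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, M, D = 2, 3, 8
query = torch.randn(B, 1, D)     # one top-down goal per batch item (assumed)
states = torch.randn(B, M, D)    # one latent per module (assumed)

# Identity as the Value tensor: softmax(QK^T / sqrt(D)) @ I returns the weights.
identity = torch.eye(M).expand(B, M, M).contiguous()
attn_weights = F.scaled_dot_product_attention(query, states, identity)  # [B, 1, M]

# Reference computation of the same weights, for comparison.
reference = torch.softmax(query @ states.transpose(1, 2) / math.sqrt(D), dim=-1)

# Hard winner-take-all: one-hot in the forward pass, soft gradients backward.
logits = (query @ states.transpose(1, 2)).squeeze(1)            # [B, M]
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)  # [B, M]
winner = (one_hot.unsqueeze(-1) * states).sum(dim=1)            # [B, D]
```

Whether a fused attention kernel is actually used depends on the device, dtype, and tensor shapes; on CPU the call falls back to a standard implementation with the same result.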

2. Low-Overhead Salience Metrics (salience.py)

These metrics score each module's activation to estimate confidence and decide whether the workspace ignites:

  • MagnitudeSalience: Computes L2 norms of states in a single vectorized pass using torch.linalg.vector_norm.
  • EntropySalience: Computes Shannon entropy and normalizes it to a confidence score in $[0, 1]$, penalizing noisy, low-confidence representations.
  • TemporalSurpriseSalience: Stateful cache that detaches gradients across iterations and uses cosine distance to measure step-to-step temporal shifts.
  • global_pool_latent: A dimension-agnostic helper that automatically maps spatial/temporal latent tensor dimensions (e.g. [B, T, D] or [B, C, H, W]) into [B, D] vectors, avoiding sequential loops.
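The metrics listed above can be approximated in a few lines each. The function names below are hypothetical, and the details are assumptions for illustration: entropy is taken over a softmax of the latent, confidence is one minus normalized entropy, and pooling averages over time or spatial axes. The library's own definitions may differ.

```python
import math
import torch
import torch.nn.functional as F

def magnitude_salience(x: torch.Tensor) -> torch.Tensor:
    """L2 norm of each [B, D] row, computed in one vectorized call."""
    return torch.linalg.vector_norm(x, ord=2, dim=-1)

def entropy_confidence(logits: torch.Tensor) -> torch.Tensor:
    """1 - H(p) / log(D): 0 for a uniform distribution, near 1 for a peaked one."""
    log_p = F.log_softmax(logits, dim=-1)
    entropy = -(log_p.exp() * log_p).sum(dim=-1)
    return 1.0 - entropy / math.log(logits.size(-1))

def temporal_surprise(current: torch.Tensor, previous: torch.Tensor) -> torch.Tensor:
    """Cosine distance to the previous step; the cached state carries no gradient."""
    return 1.0 - F.cosine_similarity(current, previous.detach(), dim=-1)

def pool_latent_sketch(x: torch.Tensor) -> torch.Tensor:
    """Reduce [B, T, D] over T, or [B, C, *spatial] over spatial axes, to 2-D."""
    if x.dim() == 2:
        return x
    if x.dim() == 3:
        return x.mean(dim=1)
    return x.flatten(2).mean(dim=-1)

norm = magnitude_salience(torch.tensor([[3.0, 4.0]]))
uniform_conf = entropy_confidence(torch.zeros(1, 4))
peaked_conf = entropy_confidence(torch.tensor([[20.0, 0.0, 0.0, 0.0]]))
surprise = temporal_surprise(torch.tensor([[1.0, 0.0]]), torch.tensor([[0.0, 1.0]]))
```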

3. Short-Term Decay Working Memory (memory.py)

  • DecayWorkingMemory: Exponentially decays inactive workspace slots in-place (workspace_state.mul_(decay_rate)) and returns a blended context vector of the active new winner and the decaying trace of the past.
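A minimal sketch of the decay-and-blend behavior described above, assuming a single stored trace and a fixed blend weight. The class `DecayMemorySketch` and its `blend` parameter are hypothetical and only illustrate the idea; they are not the DecayWorkingMemory API.

```python
import torch

class DecayMemorySketch:
    """Hypothetical stand-in: keeps one decaying trace of past workspace states."""

    def __init__(self, dim: int, decay_rate: float = 0.5, blend: float = 0.5):
        self.decay_rate = decay_rate
        self.blend = blend
        self.trace = torch.zeros(dim)

    def step(self, winner: torch.Tensor) -> torch.Tensor:
        # Decay the stored trace in place, mirroring workspace_state.mul_(decay_rate).
        self.trace.mul_(self.decay_rate)
        # Blend the new winner with what remains of the past.
        context = self.blend * winner + (1.0 - self.blend) * self.trace
        # Store a detached copy so gradients do not chain across steps.
        self.trace.copy_(context.detach())
        return context

memory = DecayMemorySketch(dim=2)
first = memory.step(torch.tensor([2.0, 0.0]))    # tensor([1., 0.])
second = memory.step(torch.tensor([0.0, 2.0]))   # tensor([0.2500, 1.0000])
```

The second context still carries a quarter of the first winner's signal, which is the "decaying trace of the past" in miniature.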

4. Parallel Output Router (routing.py)

  • CognitiveOutputRouter: Maps unified workspace representations back into dedicated output heads simultaneously using a single parallel linear projection layer and returns views using zero-copy slicing.
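The single-projection routing described above can be sketched as one wide linear layer whose output is split into per-head views. The class name `OutputRouterSketch`, the `head_dims` argument, and the head names are assumptions for illustration rather than the CognitiveOutputRouter interface.

```python
import torch
import torch.nn as nn

class OutputRouterSketch(nn.Module):
    """Hypothetical router: one fused projection, sliced into named heads."""

    def __init__(self, latent_dim: int, head_dims: dict):
        super().__init__()
        self.names = list(head_dims)
        self.sizes = list(head_dims.values())
        # A single matmul produces every head's output at once.
        self.proj = nn.Linear(latent_dim, sum(self.sizes))

    def forward(self, workspace_state: torch.Tensor) -> dict:
        fused = self.proj(workspace_state)                 # [B, sum(sizes)]
        # torch.split returns views into `fused`, so no data is copied.
        chunks = torch.split(fused, self.sizes, dim=-1)
        return dict(zip(self.names, chunks))

router = OutputRouterSketch(256, {"vision": 10, "text": 32})
outputs = router(torch.randn(4, 256))
vision_shape = tuple(outputs["vision"].shape)
text_shape = tuple(outputs["text"].shape)
```

Fusing the heads trades per-head flexibility for one larger matrix multiply, which tends to use the hardware better than several small ones.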

For a full demo showing how to attach and run these components in an active training loop, see example_addons.py.


License

This project is licensed under the MIT License; see the LICENSE file for details.
