Modular brain-inspired cognitive augmentation layers for neural networks

Cognitive Augmentation Layers (cognitive_aug)

Modular, brain-inspired cognitive augmentation layers for neural networks, built atop PyTorch.

This library provides plug-and-play components to endow existing deep learning architectures with biologically inspired capabilities, starting with Global Workspace Theory (GWT) mechanisms for attention, integration, and broad dynamic routing.


Features (Phase v0.1 MVP)

  • Module Registry: Track and coordinate active neural networks acting as specialized cognitive modules.
  • Data Flow Manager: High-performance, framework-integrated routing of latent representations across dynamic computation cycles.
  • Global Workspace Theory (GWT) Layer:
    • Configurable single-slot (high biological fidelity) and multi-slot (engineering focus) workspace bottleneck.
    • Pluggable AttentionSelector supporting Top-Down Key-Query matching and Bottom-Up salience selection.
    • Dynamic Ignition: Non-linear, threshold-gated attentional selection where only highly active features are broadcast.
    • BroadcastEngine that distributes unified cognitive states back to all modules as context.
  • Non-Intrusive Wrappers: ModuleAdapter wraps standard PyTorch nn.Modules using forward/backward hooks without polluting or altering original model classes.
  • Optimized Cognitive Add-ons:
    • Differentiable Selection: CosineSimilaritySelector, VectorizedCrossAttentionSelector (built on PyTorch's scaled_dot_product_attention, which can dispatch to FlashAttention kernels where supported), and EfficientGumbelSoftmaxSelector (hard winner-take-all routing).
    • Low-Overhead Salience: MagnitudeSalience (L2 norm), EntropySalience (Shannon entropy confidence), and stateful TemporalSurpriseSalience (temporal cosine distance tracking).
    • Decay Working Memory: DecayWorkingMemory stateful wrapper with in-place exponential decay mutations and blended historical traces.
    • Parallel Downstream Routing: CognitiveOutputRouter broadcasting workspace representations back into multiple output heads in a single vectorized matrix projection pass.
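
The Dynamic Ignition feature above describes threshold-gated selection. The snippet below is a minimal sketch of that idea in plain PyTorch, not the library's implementation: the function name ignite, the tensor layout, and the choice of a softmax competition across modules are all assumptions made for illustration.

```python
import torch

def ignite(candidates, salience, threshold=0.5):
    """Hypothetical helper, not part of cognitive_aug.

    candidates: [M, B, D] latent states from M modules.
    salience:   [M, B] unnormalized salience scores.
    Returns a [B, D] broadcast tensor and a [B] boolean ignition mask.
    """
    # Modules compete per sample: softmax over the module axis.
    weights = torch.softmax(salience, dim=0)
    top_weight, winner = weights.max(dim=0)            # both [B]
    # Ignition: broadcast only when the winner clearly dominates.
    ignited = top_weight > threshold
    batch_idx = torch.arange(candidates.size(1), device=candidates.device)
    selected = candidates[winner, batch_idx]           # [B, D]
    # Samples that did not ignite broadcast zeros.
    broadcast = selected * ignited.unsqueeze(-1).to(selected.dtype)
    return broadcast, ignited

torch.manual_seed(0)
candidates = torch.randn(3, 4, 256)                    # 3 modules, batch of 4
salience = torch.tensor([
    [5.0, 0.0, 0.1, -3.0],
    [0.0, 0.0, 0.0,  3.0],
    [0.0, 0.0, 0.2,  0.0],
])
broadcast, ignited = ignite(candidates, salience, threshold=0.5)
print(ignited)
```

With a single-slot workspace, at most one module's state reaches the broadcast per sample; raising the threshold makes ignition rarer.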

Installation

To install in editable mode with development dependencies:

pip install -e ".[dev]"

Quickstart Example

Here is a simple example showing how to register an image classification module and a text processing module, enabling them to communicate via a shared Global Workspace:

import torch
import torch.nn as nn
from cognitive_aug import CognitiveAugEngine, GlobalWorkspace, ModuleAdapter

# 1. Define your standard PyTorch modules
class VisualModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 8 * 8, 256)
    def forward(self, x):
        features = torch.relu(self.conv(x))
        features = features.view(features.size(0), -1)
        return self.fc(features)

class TextModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)
        self.fc = nn.Linear(128, 256)
    def forward(self, x):
        emb = self.emb(x)
        out, _ = self.lstm(emb)
        return self.fc(out[:, -1, :])

# 2. Instantiate modules
vis_mod = VisualModule()
txt_mod = TextModule()

# 3. Create the Cognitive Engine and Register Modules
engine = CognitiveAugEngine()

# Register modules with their latent dimensions (must match workspace target or be mapped)
vis_adapter = engine.register_module(
    name="vision",
    module=vis_mod,
    latent_dim=256
)

txt_adapter = engine.register_module(
    name="text",
    module=txt_mod,
    latent_dim=256
)

# 4. Instantiate Global Workspace
workspace = GlobalWorkspace(
    latent_dim=256,
    attention_type="key-query",
    ignition_threshold=0.5,
    workspace_slots=1  # Biological single-slot GWT bottleneck
)

# Attach workspace to the engine
engine.attach_workspace(workspace)

# 5. Run standard forward propagation
# Adapters will automatically capture latent representations via PyTorch forward hooks!
dummy_img = torch.randn(4, 3, 8, 8)
dummy_txt = torch.randint(0, 1000, (4, 10))

# Forward pass as usual
vis_out = vis_mod(dummy_img)
txt_out = txt_mod(dummy_txt)

# Perform GWT cycle: Selection & Global Broadcast!
workspace_state = engine.step()

print("Global Workspace broadcast vector shape:", workspace_state.shape)
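
Step 5 relies on PyTorch forward hooks to capture latents without changing the wrapped model. As a rough illustration of that mechanism only, the sketch below uses nn.Module.register_forward_hook directly; the LatentCapture class is a hypothetical stand-in and makes no claim about how ModuleAdapter is written internally.

```python
import torch
import torch.nn as nn

class LatentCapture:
    """Hypothetical stand-in for an adapter: remembers a module's latest output."""

    def __init__(self, module):
        self.latent = None
        # PyTorch calls hook(module, inputs, output) after every forward().
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Keep the tensor attached to the graph so gradients can still flow.
        # Returning None leaves the module's output unchanged.
        self.latent = output

    def remove(self):
        # Detach the hook and restore the module's original behavior.
        self.handle.remove()

encoder = nn.Linear(8, 256)
capture = LatentCapture(encoder)

out = encoder(torch.randn(4, 8))       # ordinary forward pass
print(capture.latent.shape)            # the hook saw the same tensor

capture.remove()
later = encoder(torch.randn(4, 8))     # no longer captured
```

Because the hook lives on the module instance rather than in its class, the original model definition stays untouched.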

High-Performance Modular Add-ons

For advanced cognitive systems or performance-critical setups, the package includes highly optimized, zero-copy, fully vectorized components.

1. Advanced Attention Selectors (selectors.py)

Replace the default selector with one of three high-speed subclasses inheriting from BaseSelector:

  • CosineSimilaritySelector: Uses vectorized matrix-based cosine_similarity to rapidly match incoming states to top-down goals.
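  • VectorizedCrossAttentionSelector: Uses PyTorch's native scaled_dot_product_attention with an identity matrix as the Value tensor, so the fused call returns the attention weights themselves. PyTorch can dispatch this call to FlashAttention kernels where the device, dtype, and shapes allow it, and the operation remains fully differentiable.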
  • EfficientGumbelSoftmaxSelector: A hard, winner-take-all routing mechanism using single-pass gumbel_softmax that retains full differentiability.
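
Two of the techniques named above can be sketched with stock PyTorch calls. The function names, tensor shapes, and the [B, N, N] identity layout below are assumptions for illustration, not the selectors' actual code. Which attention backend PyTorch picks depends on device, dtype, and shapes, so this sketch makes no speed claim.

```python
import torch
import torch.nn.functional as F

def attention_weights_via_sdpa(query, keys):
    """query: [B, 1, D] goal, keys: [B, N, D] module states -> weights [B, 1, N]."""
    n = keys.size(1)
    # With V = I, softmax(Q K^T / sqrt(D)) @ V is just the attention weights.
    identity = torch.eye(n, dtype=keys.dtype, device=keys.device)
    identity = identity.expand(keys.size(0), n, n)
    return F.scaled_dot_product_attention(query, keys, identity)

def hard_select(logits, tau=1.0):
    """One-hot winner per row; hard=True uses a straight-through estimator,
    so the backward pass follows the soft Gumbel-softmax sample."""
    return F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)

torch.manual_seed(0)
goal = torch.randn(4, 1, 256)
module_states = torch.randn(4, 3, 256)
weights = attention_weights_via_sdpa(goal, module_states)   # [4, 1, 3]

logits = torch.randn(4, 3, requires_grad=True)
one_hot = hard_select(logits)                                # [4, 3]
module_values = torch.randn(4, 3)
(one_hot * module_values).sum().backward()                   # gradients reach logits
print(weights.shape, one_hot.shape)
```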

2. Low-Overhead Salience Metrics (salience.py)

Evaluate modular activation to compute confidence and decide workspace ignition:

  • MagnitudeSalience: Computes L2 norms of states in a single vectorized pass using torch.linalg.vector_norm.
  • EntropySalience: Computes Shannon entropy and normalizes it to a confidence score in [0, 1], penalizing noisy, low-confidence representations.
  • TemporalSurpriseSalience: Stateful cache that detaches gradients across iterations and uses cosine distance to measure step-to-step temporal shifts.
  • global_pool_latent: A dimension-agnostic helper that automatically maps spatial/temporal latent tensor dimensions (e.g. [B, T, D] or [B, C, H, W]) into [B, D] vectors, avoiding sequential loops.
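
The metrics listed above can be approximated in a few lines. The sketch below is an illustration under stated assumptions rather than the contents of salience.py: the function and class names are hypothetical, confidence is taken to be 1 - H / log(D), and the pooling helper assumes [B, T, D] for 3-D inputs and channels-first layout for 4-D and higher.

```python
import math
import torch
import torch.nn.functional as F

def global_pool_latent(x):
    """Assumed layouts: [B, D], [B, T, D], or [B, C, *spatial] -> [B, D] / [B, C]."""
    if x.dim() == 2:
        return x
    if x.dim() == 3:
        return x.mean(dim=1)                      # average over time
    return x.flatten(start_dim=2).mean(dim=-1)    # average over spatial axes

def magnitude_salience(x):
    """L2 norm of each row in one vectorized call."""
    return torch.linalg.vector_norm(x, ord=2, dim=-1)

def entropy_confidence(logits):
    """1 - H(p) / log(D): 0 for a uniform distribution, near 1 for a peaked one."""
    log_p = F.log_softmax(logits, dim=-1)
    entropy = -(log_p.exp() * log_p).sum(dim=-1)
    return 1.0 - entropy / math.log(logits.size(-1))

class TemporalSurprise:
    """Cosine distance between the current state and the previous one."""

    def __init__(self):
        self.previous = None

    def __call__(self, x):
        if self.previous is None:
            surprise = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        else:
            surprise = 1.0 - F.cosine_similarity(x, self.previous, dim=-1)
        # Detach so the cache does not extend the autograd graph across steps.
        self.previous = x.detach()
        return surprise

torch.manual_seed(0)
state = torch.randn(2, 16)
tracker = TemporalSurprise()
first = tracker(state)       # no history yet
repeat = tracker(state)      # identical state
flipped = tracker(-state)    # opposite direction
print(first, repeat, flipped)
```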

3. Short-Term Decay Working Memory (memory.py)

  • DecayWorkingMemory: Exponentially decays inactive workspace slots in-place (workspace_state.mul_(decay_rate)) and returns a blended context vector of the active new winner and the decaying trace of the past.
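
A minimal sketch of the decay-and-blend behavior described above, assuming a single trace vector and a fixed blend weight. The class name DecayMemorySketch, the blend parameter, and the choice to store the blended context as the next trace are illustrative assumptions, not the DecayWorkingMemory API.

```python
import torch

class DecayMemorySketch:
    """Hypothetical illustration of an exponentially decaying memory trace."""

    def __init__(self, latent_dim, decay_rate=0.9, blend=0.5):
        self.decay_rate = decay_rate
        self.blend = blend
        self.trace = torch.zeros(latent_dim)

    def step(self, winner):
        # In-place exponential decay of the stored trace.
        self.trace.mul_(self.decay_rate)
        # Blend the new winner with what remains of the past.
        context = self.blend * winner + (1.0 - self.blend) * self.trace
        # Store a detached copy so history does not keep old graphs alive.
        self.trace.copy_(context.detach())
        return context

memory = DecayMemorySketch(latent_dim=4, decay_rate=0.5, blend=0.5)
first = memory.step(torch.ones(4))     # 0.5 * 1 + 0.5 * 0
second = memory.step(torch.zeros(4))   # 0.5 * 0 + 0.5 * (0.5 * 0.5)
print(first, second)
```

The in-place mul_ avoids allocating a new tensor on every cycle, which is the point of the mutation mentioned above.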

4. Parallel Output Router (routing.py)

  • CognitiveOutputRouter: Maps unified workspace representations back into dedicated output heads simultaneously using a single parallel linear projection layer and returns views using zero-copy slicing.
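
The single-projection idea can be sketched as one nn.Linear whose output is split into per-head views. The class name OutputRouterSketch and the head_dims argument are assumptions for illustration; the real CognitiveOutputRouter may differ.

```python
import torch
import torch.nn as nn

class OutputRouterSketch(nn.Module):
    """Hypothetical router: one matmul for all heads, then zero-copy slicing."""

    def __init__(self, latent_dim, head_dims):
        super().__init__()
        self.head_dims = list(head_dims)
        # One fused projection instead of one Linear layer per head.
        self.proj = nn.Linear(latent_dim, sum(self.head_dims))

    def forward(self, workspace_state):
        fused = self.proj(workspace_state)                   # [B, sum(head_dims)]
        # torch.split returns views into `fused`; no data is copied.
        return torch.split(fused, self.head_dims, dim=-1)

torch.manual_seed(0)
router = OutputRouterSketch(latent_dim=256, head_dims=[10, 32, 4])
class_logits, text_features, control = router(torch.randn(4, 256))
print(class_logits.shape, text_features.shape, control.shape)
```

Fusing the heads trades several small matrix multiplications for one larger one, which is usually friendlier to GPU kernels.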

For a full demo showing how to attach and run these components in an active training loop, see example_addons.py.


License

This project is licensed under the MIT License; see the LICENSE file for details.
