Modular brain-inspired cognitive augmentation layers for neural networks

Cognitive Augmentation Layers (cognitive_aug)

Modular, brain-inspired cognitive augmentation layers for neural networks, built atop PyTorch.

This library provides plug-and-play components to endow existing deep learning architectures with biologically inspired capabilities, starting with Global Workspace Theory (GWT) mechanisms for attention, integration, and broad dynamic routing.


Features (Phase v0.2 - Active Dendritic Gating)

  • Module Registry: Track and coordinate active neural networks acting as specialized cognitive modules.
  • Data Flow Manager: High-performance, framework-integrated routing of latent representations across dynamic computation cycles.
  • Global Workspace Theory (GWT) Layer:
    • Configurable single-slot (high biological fidelity) and multi-slot (engineering focus) workspace bottleneck.
    • Pluggable AttentionSelector supporting Top-Down Key-Query matching and Bottom-Up salience selection.
    • Dynamic Ignition: Non-linear, threshold-gated attentional selection where only highly active features are broadcast.
    • BroadcastEngine that distributes unified cognitive states back to all modules as context.
  • Active Dendritic Gating (Phase v0.2):
    • ActiveDendriteGate: Vectorized dendritic pre-processing that uses GWT context to modulate feedforward features.
    • Modulatory Gain: Smooth sigmoidal scaling mapping context directly to features.
    • NMDA Threshold Spiking: Sharp thresholding mimicking biological NMDA spikes (zeroing out inactive branches) with Straight-Through Estimators (STE) to ensure clean backpropagation.
    • Telemetry Dashboard: Dynamic pathway inspection and active/muted statistics.
  • Non-Intrusive Wrappers: ModuleAdapter wraps standard PyTorch nn.Modules using forward/backward hooks without polluting or altering original model classes.
  • Optimized Cognitive Add-ons:
    • Differentiable Selection: CosineSimilaritySelector, VectorizedCrossAttentionSelector (built on PyTorch's scaled_dot_product_attention, which can dispatch to fused FlashAttention kernels where supported), and EfficientGumbelSoftmaxSelector (hard winner-take-all routing).
    • Low-Overhead Salience: MagnitudeSalience (L2 norm), EntropySalience (Shannon entropy confidence), and stateful TemporalSurpriseSalience (temporal cosine distance tracking).
    • Decay Working Memory: DecayWorkingMemory stateful wrapper with in-place exponential decay mutations and blended historical traces.
    • Parallel Downstream Routing: CognitiveOutputRouter broadcasting workspace representations back into multiple output heads in a single vectorized matrix projection pass.
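
The ignition step described above can be pictured with a minimal sketch. This is not the library's implementation: the function name ignite_and_broadcast, the scaled dot-product scoring, and the choice to broadcast zeros when nothing ignites are all assumptions made for illustration.

```python
import torch

def ignite_and_broadcast(latents, query, threshold=0.5):
    """Hypothetical single-slot ignition (illustrative only).

    latents: [B, N, D], one latent vector per registered module
    query:   [B, D], a top-down goal vector
    """
    # Key-query match: scaled dot product between each module latent and the goal.
    scores = torch.einsum("bnd,bd->bn", latents, query) / latents.shape[-1] ** 0.5
    weights = torch.softmax(scores, dim=-1)                     # [B, N]
    top_weight, top_index = weights.max(dim=-1)                 # [B], [B]
    winner = latents[torch.arange(latents.shape[0]), top_index]  # [B, D]
    # Ignition: broadcast only when the winning weight crosses the threshold.
    ignited = top_weight >= threshold
    broadcast = winner * ignited.unsqueeze(-1).to(winner.dtype)
    return broadcast, ignited
```

Note that the hard comparison itself passes no gradient to the threshold decision; in this sketch gradients reach the winning latent only.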

Installation

To install in editable mode with development dependencies:

pip install -e ".[dev]"

Quickstart Example

Here is a simple example showing how to register an image classification module and a text processing module, enabling them to communicate via a shared Global Workspace:

import torch
import torch.nn as nn
from cognitive_aug import CognitiveAugEngine, GlobalWorkspace, ModuleAdapter

# 1. Define your standard PyTorch modules
class VisualModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.fc = nn.Linear(16 * 8 * 8, 256)
    def forward(self, x):
        features = torch.relu(self.conv(x))
        features = features.view(features.size(0), -1)
        return self.fc(features)

class TextModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)
        self.fc = nn.Linear(128, 256)
    def forward(self, x):
        emb = self.emb(x)
        out, _ = self.lstm(emb)
        return self.fc(out[:, -1, :])

# 2. Instantiate modules
vis_mod = VisualModule()
txt_mod = TextModule()

# 3. Create the Cognitive Engine and Register Modules
engine = CognitiveAugEngine()

# Register modules with their latent dimensions (must match workspace target or be mapped)
vis_adapter = engine.register_module(
    name="vision",
    module=vis_mod,
    latent_dim=256
)

txt_adapter = engine.register_module(
    name="text",
    module=txt_mod,
    latent_dim=256
)

# 4. Instantiate Global Workspace
workspace = GlobalWorkspace(
    latent_dim=256,
    attention_type="key-query",
    ignition_threshold=0.5,
    workspace_slots=1  # Biological single-slot GWT bottleneck
)

# Attach workspace to the engine
engine.attach_workspace(workspace)

# 5. Run standard forward propagation
# Adapters will automatically capture latent representations via PyTorch forward hooks!
dummy_img = torch.randn(4, 3, 8, 8)
dummy_txt = torch.randint(0, 1000, (4, 10))

# Forward pass as usual
vis_out = vis_mod(dummy_img)
txt_out = txt_mod(dummy_txt)

# Perform GWT cycle: Selection & Global Broadcast!
workspace_state = engine.step()

print("Global Workspace broadcast vector shape:", workspace_state.shape)
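
To make the hook-based capture in step 5 less of a black box, here is a minimal sketch of how a wrapper can record a module's output without modifying its class. LatentCapture is a hypothetical stand-in, not the library's ModuleAdapter; only register_forward_hook is actual PyTorch API.

```python
import torch
import torch.nn as nn

class LatentCapture:
    """Hypothetical adapter stand-in: remembers a module's most recent output."""

    def __init__(self, module: nn.Module):
        self.latent = None
        # PyTorch calls hook(module, inputs, output) after every forward().
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Returning None leaves the module's output untouched.
        self.latent = output

    def remove(self):
        # Detach the hook; the wrapped module is left exactly as it was.
        self.handle.remove()

torch.manual_seed(0)
layer = nn.Linear(8, 4)
capture = LatentCapture(layer)
out = layer(torch.randn(2, 8))   # ordinary forward call
print(capture.latent.shape)      # the hook stored the same tensor that was returned
```

Because the hook lives on the module instance, the original model class needs no changes, which is what makes this style of wrapper non-intrusive.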

High-Performance Modular Add-ons

For advanced cognitive systems or performance-critical setups, the package includes highly optimized, zero-copy, fully vectorized components.

1. Advanced Attention Selectors (selectors.py)

Replace the default selector with one of three high-speed subclasses inheriting from BaseSelector:

  • CosineSimilaritySelector: Uses vectorized matrix-based cosine_similarity to rapidly match incoming states to top-down goals.
  • VectorizedCrossAttentionSelector: Uses PyTorch's native scaled_dot_product_attention with an identity matrix as the Value tensor, so the output is the attention-weight matrix itself. The call stays fully differentiable and can dispatch to fused FlashAttention kernels when the device, dtype, and tensor shapes allow it.
  • EfficientGumbelSoftmaxSelector: A hard, winner-take-all routing mechanism using single-pass gumbel_softmax that retains full differentiability.
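
The two less obvious selectors can be sketched with plain PyTorch calls. The function names, tensor layouts, and the use of hard=True below are assumptions for illustration, not the package's BaseSelector interface; scaled_dot_product_attention and gumbel_softmax are the real torch.nn.functional functions.

```python
import torch
import torch.nn.functional as F

def attention_weights_via_sdpa(query, keys):
    """Identity-value trick: softmax(q k^T / sqrt(D)) @ I equals the weights.

    query: [B, D] top-down goal, keys: [B, N, D] one latent per module.
    """
    B, N, D = keys.shape
    identity = torch.eye(N, dtype=keys.dtype, device=keys.device).expand(B, N, N)
    weights = F.scaled_dot_product_attention(query.unsqueeze(1), keys, identity)
    return weights.squeeze(1)  # [B, N], each row sums to 1

def hard_select(scores, keys, tau=1.0):
    """Winner-take-all routing with a straight-through Gumbel-softmax.

    scores: [B, N] selection logits, keys: [B, N, D].
    """
    # hard=True yields a one-hot forward pass; gradients use the soft sample.
    one_hot = F.gumbel_softmax(scores, tau=tau, hard=True, dim=-1)
    winner = torch.einsum("bn,bnd->bd", one_hot, keys)
    return one_hot, winner
```

Whether the attention call actually runs on a fused kernel depends on the backend PyTorch selects at runtime; on CPU or with unsupported shapes it falls back to the standard math path and returns the same values.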

2. Low-Overhead Salience Metrics (salience.py)

Evaluate modular activation to compute confidence and decide workspace ignition:

  • MagnitudeSalience: Computes L2 norms of states in a single vectorized pass using torch.linalg.vector_norm.
  • EntropySalience: Computes Shannon entropy normalized to $[0, 1]$ confidence, penalizing noisy, unconfident representations.
  • TemporalSurpriseSalience: Stateful cache that detaches gradients across iterations and uses cosine distance to measure step-to-step temporal shifts.
  • global_pool_latent: A dimension-agnostic helper that automatically maps spatial/temporal latent tensor dimensions (e.g. [B, T, D] or [B, C, H, W]) into [B, D] vectors, avoiding sequential loops.
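
As a rough illustration of the metrics listed above, the sketch below computes each one in a vectorized way. All names are hypothetical, and the exact conventions (confidence defined as one minus normalized entropy, mean pooling, zero surprise on the first call) are assumptions rather than the library's documented behavior.

```python
import math
import torch
import torch.nn.functional as F

def magnitude_salience(latents):
    # L2 norm over the feature dimension: [B, N, D] -> [B, N]
    return torch.linalg.vector_norm(latents, ord=2, dim=-1)

def entropy_confidence(logits):
    # Shannon entropy of softmax(logits), normalized by log(D) and flipped,
    # so 1.0 means a sharply peaked distribution and 0.0 a uniform one.
    log_p = F.log_softmax(logits, dim=-1)
    entropy = -(log_p.exp() * log_p).sum(dim=-1)
    return 1.0 - entropy / math.log(logits.shape[-1])

def pool_to_vector(x):
    # Assumed layouts: [B, D] unchanged, [B, T, D] averaged over T,
    # channels-first [B, C, H, W] averaged over the spatial dimensions.
    if x.dim() == 2:
        return x
    if x.dim() == 3:
        return x.mean(dim=1)
    return x.flatten(2).mean(dim=-1)

class TemporalSurprise:
    """Cosine distance between the current latent and the previous one."""

    def __init__(self):
        self.prev = None

    def __call__(self, latent):  # latent: [B, D]
        if self.prev is None:
            surprise = torch.zeros(latent.shape[0], dtype=latent.dtype, device=latent.device)
        else:
            surprise = 1.0 - F.cosine_similarity(latent, self.prev, dim=-1)
        # Detach so the cached state does not keep the old graph alive.
        self.prev = latent.detach()
        return surprise
```

Detaching the cached latent keeps memory use flat across iterations, since no computation graph from earlier steps is retained.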

3. Short-Term Decay Working Memory (memory.py)

  • DecayWorkingMemory: Exponentially decays inactive workspace slots in-place (workspace_state.mul_(decay_rate)) and returns a blended context vector of the active new winner and the decaying trace of the past.
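
A minimal sketch of the decay-and-blend idea follows. It tracks a single trace vector rather than multiple slots, and the class name, the blend parameter, and the update order are assumptions made for illustration; only the in-place mul_ decay comes from the description above.

```python
import torch

class DecayMemorySketch:
    """Hypothetical single-trace working memory (illustrative only)."""

    def __init__(self, dim, decay_rate=0.9, blend=0.5):
        self.trace = torch.zeros(dim)   # decaying record of past winners
        self.decay_rate = decay_rate
        self.blend = blend

    def step(self, winner):
        # In-place exponential decay of the stored trace (no new allocation).
        self.trace.mul_(self.decay_rate)
        # Blend the fresh winner with what remains of the past.
        context = self.blend * winner + (1.0 - self.blend) * self.trace
        # Store a detached copy so the trace never carries autograd history.
        self.trace.copy_(context.detach())
        return context

memory = DecayMemorySketch(dim=2, decay_rate=0.5, blend=0.5)
first = memory.step(torch.ones(2))    # trace starts empty, so only the winner contributes
second = memory.step(torch.zeros(2))  # no new input: only the decayed trace remains
```

With decay_rate=0.5 and blend=0.5, the first call returns 0.5 per element and the second returns 0.125, which shows the past fading geometrically when no new winner arrives.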

4. Parallel Output Router (routing.py)

  • CognitiveOutputRouter: Maps unified workspace representations back into dedicated output heads simultaneously using a single parallel linear projection layer and returns views using zero-copy slicing.
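
The single-projection routing pattern can be sketched as below. OutputRouterSketch and its head_dims argument are hypothetical names, not the package's CognitiveOutputRouter signature; the point is that one nn.Linear feeds every head and torch.split hands back views of its output.

```python
import torch
import torch.nn as nn

class OutputRouterSketch(nn.Module):
    """Hypothetical router: one fused projection, sliced into per-head views."""

    def __init__(self, workspace_dim, head_dims):
        super().__init__()
        self.head_dims = list(head_dims)
        # A single weight matrix covers all heads at once.
        self.proj = nn.Linear(workspace_dim, sum(self.head_dims))

    def forward(self, workspace_state):
        fused = self.proj(workspace_state)   # [B, sum(head_dims)]
        # torch.split returns views into `fused`; no data is copied.
        return torch.split(fused, self.head_dims, dim=-1)

torch.manual_seed(0)
router = OutputRouterSketch(workspace_dim=256, head_dims=[10, 4])
class_logits, action_logits = router(torch.randn(2, 256))
```

One large matrix multiply is generally cheaper to launch than several small ones, which is the motivation for fusing the heads before slicing.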

5. Active Dendritic Gating (gwt/dendrite.py) [Phase v0.2]

Biologically-inspired active dendritic pre-processors that dynamically modulate local feedforward pathways using GWT context:

  • ActiveDendriteGate: Fully vectorized module that projects context [B, context_dim] onto dendritic branches.
    • "modulatory-gain": Smooth sigmoidal contextual scaling.
    • "nmda-threshold": Sharp thresholded biological spiking. Spikes if local depolarization $\ge \text{threshold}$, otherwise zeroed out. Uses a Straight-Through Estimator (STE) for clean training gradient flow.
  • DendriticModuleAdapter: Special subclass of ModuleAdapter that automatically detects model output dimensions statically (or dynamically on the first forward pass) and seamlessly appends dendritic context-gating onto the module's execution hook pipeline.
  • get_dendritic_status: Telemetry helper scanning modules recursively to report active vs. muted pathway percentages across all dendritic gates.
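
The two gating modes can be illustrated with a short sketch. DendriteGateSketch is a hypothetical class that borrows the mode strings from the list above; the linear context projection and the sigmoid surrogate used for the straight-through gradient are assumptions, not the library's confirmed design.

```python
import torch
import torch.nn as nn

class DendriteGateSketch(nn.Module):
    """Hypothetical context gate with a soft mode and a hard STE mode."""

    def __init__(self, context_dim, feature_dim, mode="modulatory-gain", threshold=0.0):
        super().__init__()
        self.proj = nn.Linear(context_dim, feature_dim)
        self.mode = mode
        self.threshold = threshold

    def forward(self, features, context):
        # "Local depolarization" per feature: [B, feature_dim]
        drive = self.proj(context)
        if self.mode == "modulatory-gain":
            gate = torch.sigmoid(drive)                      # smooth scaling in (0, 1)
        elif self.mode == "nmda-threshold":
            hard = (drive >= self.threshold).to(drive.dtype)  # 1 if spiking, else 0
            soft = torch.sigmoid(drive - self.threshold)
            # Straight-through estimator: forward value is `hard`,
            # backward gradient is that of `soft`.
            gate = soft + (hard - soft).detach()
        else:
            raise ValueError(f"unknown mode: {self.mode}")
        return features * gate

def active_fraction(gated):
    # Simple telemetry: share of pathways that were not zeroed out.
    return (gated != 0).to(torch.float32).mean().item()
```

Without the straight-through term, the hard comparison would have zero gradient everywhere and the context projection could not be trained in the thresholded mode.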

For a full demo showing how to attach and run these components in an active training loop, see example_addons.py.


License

This project is licensed under the MIT License. See the LICENSE file for details.
