Modular brain-inspired cognitive augmentation layers for neural networks
Cognitive Augmentation Layers (cognitive_aug)
Modular, brain-inspired cognitive augmentation layers for neural networks, built atop PyTorch.
This library provides plug-and-play components to endow existing deep learning architectures with biologically inspired capabilities, starting with Global Workspace Theory (GWT) mechanisms for attention, integration, and broad dynamic routing.
Features (Phase v0.2 - Active Dendritic Gating)
- Module Registry: Tracks and coordinates the neural networks that act as specialized cognitive modules.
- Data Flow Manager: High-performance, PyTorch-integrated routing of latent representations across dynamic computation cycles.
- Global Workspace Theory (GWT) Layer:
  - Configurable `single-slot` (high biological fidelity) and `multi-slot` (engineering focus) workspace bottleneck.
  - Pluggable `AttentionSelector` supporting top-down key-query matching and bottom-up salience selection.
  - Dynamic Ignition: non-linear, threshold-gated attentional selection in which only highly active features are broadcast.
  - `BroadcastEngine`, which distributes the unified cognitive state back to all modules as context.
- Active Dendritic Gating (Phase v0.2):
  - `ActiveDendriteGate`: vectorized dendritic pre-processing that uses GWT context to modulate feedforward features.
  - Modulatory Gain: smooth sigmoidal scaling that maps context directly onto features.
  - NMDA Threshold Spiking: sharp thresholding that mimics biological NMDA spikes (zeroing out inactive branches), trained with a Straight-Through Estimator (STE) so gradients still flow through the hard threshold.
  - Telemetry Dashboard: dynamic pathway inspection with active/muted statistics.
- Non-Intrusive Wrappers:
  - `ModuleAdapter` wraps standard PyTorch `nn.Module`s using forward/backward hooks, without modifying the original model classes.
- Optimized Cognitive Add-ons:
  - Differentiable Selection: `CosineSimilaritySelector`, `VectorizedCrossAttentionSelector` (built on PyTorch's fused attention, which can dispatch to FlashAttention kernels on supported hardware), and `EfficientGumbelSoftmaxSelector` (hard winner-take-all routing).
  - Low-Overhead Salience: `MagnitudeSalience` (L2 norm), `EntropySalience` (Shannon-entropy confidence), and the stateful `TemporalSurpriseSalience` (temporal cosine-distance tracking).
  - Decay Working Memory: `DecayWorkingMemory`, a stateful wrapper that applies exponential decay in place and blends in historical traces.
  - Parallel Downstream Routing: `CognitiveOutputRouter`, which broadcasts workspace representations into multiple output heads in a single vectorized matrix projection.
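The Dynamic Ignition idea (select one winner, but broadcast only if it clears a threshold) can be illustrated with a minimal single-slot cycle in plain PyTorch. This is a sketch under stated assumptions, not the library's `GlobalWorkspace`: the function name `ignition_broadcast`, the softmax competition over scores, and comparing the winning weight against the threshold are all illustrative choices.

```python
import torch


def ignition_broadcast(states: torch.Tensor, scores: torch.Tensor, threshold: float = 0.5):
    """Hypothetical single-slot workspace cycle (illustrative, not the library API).

    states: [B, N, D] latents from N modules
    scores: [B, N] unnormalized selection scores (e.g. from a selector or salience metric)
    Returns the broadcast context [B, N, D] and a boolean ignition mask [B].
    """
    weights = torch.softmax(scores, dim=-1)          # modules compete for the single slot
    top_w, winner = weights.max(dim=-1)              # strongest candidate per sample
    ignited = top_w >= threshold                     # non-linear, threshold-gated ignition
    batch = torch.arange(states.size(0))
    winner_state = states[batch, winner]             # [B, D] content of the winning module
    # Samples that fail to ignite broadcast an all-zero context instead.
    winner_state = winner_state * ignited.unsqueeze(-1).to(states.dtype)
    # Send the single workspace slot back to every module as context.
    broadcast = winner_state.unsqueeze(1).expand(-1, states.size(1), -1)
    return broadcast, ignited


states = torch.randn(2, 3, 4)
scores = torch.tensor([[5.0, 0.0, 0.0],    # one dominant module: ignites
                       [0.0, 0.0, 0.0]])   # uniform competition: no ignition
broadcast, ignited = ignition_broadcast(states, scores, threshold=0.5)
```

With three modules, a uniform competition gives each a weight of 1/3, which stays below a threshold of 0.5, so nothing is broadcast; a single dominant score pushes its weight close to 1 and ignites.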
Installation
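To install the released package from PyPI:
```bash
pip install cognitive_aug
```
To install from a clone of the repository in editable mode, with development dependencies:
```bash
pip install -e ".[dev]"
```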
Quickstart Example
Here is a simple example showing how to register a vision module and a text-processing module so that they can communicate through a shared Global Workspace:
```python
import torch
import torch.nn as nn
from cognitive_aug import CognitiveAugEngine, GlobalWorkspace, ModuleAdapter

# 1. Define your standard PyTorch modules
class VisualModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        # Expects 8x8 images: padding=1 preserves spatial size, giving 16 * 8 * 8 features
        self.fc = nn.Linear(16 * 8 * 8, 256)

    def forward(self, x):
        features = torch.relu(self.conv(x))
        features = features.view(features.size(0), -1)
        return self.fc(features)

class TextModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)
        self.fc = nn.Linear(128, 256)

    def forward(self, x):
        emb = self.emb(x)
        out, _ = self.lstm(emb)
        return self.fc(out[:, -1, :])  # use the last time step as the sentence latent

# 2. Instantiate modules
vis_mod = VisualModule()
txt_mod = TextModule()

# 3. Create the Cognitive Engine and Register Modules
engine = CognitiveAugEngine()

# Register modules with their latent dimensions (must match workspace target or be mapped)
vis_adapter = engine.register_module(
    name="vision",
    module=vis_mod,
    latent_dim=256
)
txt_adapter = engine.register_module(
    name="text",
    module=txt_mod,
    latent_dim=256
)

# 4. Instantiate Global Workspace
workspace = GlobalWorkspace(
    latent_dim=256,
    attention_type="key-query",
    ignition_threshold=0.5,
    workspace_slots=1  # Biological single-slot GWT bottleneck
)

# Attach workspace to the engine
engine.attach_workspace(workspace)

# 5. Run standard forward propagation
# Adapters will automatically capture latent representations via PyTorch forward hooks!
dummy_img = torch.randn(4, 3, 8, 8)
dummy_txt = torch.randint(0, 1000, (4, 10))

# Forward pass as usual
vis_out = vis_mod(dummy_img)
txt_out = txt_mod(dummy_txt)

# Perform GWT cycle: Selection & Global Broadcast!
workspace_state = engine.step()
print("Global Workspace broadcast vector shape:", workspace_state.shape)
```
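The quickstart relies on adapters capturing module outputs through forward hooks, so the modules' own classes never change. The standalone sketch below shows that mechanism with PyTorch's `register_forward_hook`; the `LatentCapture` class is hypothetical and is not the library's `ModuleAdapter` (which, per the feature list, also uses backward hooks).

```python
import torch
import torch.nn as nn


class LatentCapture:
    """Hypothetical helper: records a module's output via a forward hook,
    leaving the module's class and forward() untouched."""

    def __init__(self):
        self.latents = {}
        self._handles = []

    def register(self, name: str, module: nn.Module) -> None:
        def hook(mod, inputs, output):
            # Store whatever the module returned as its latent representation.
            self.latents[name] = output
        self._handles.append(module.register_forward_hook(hook))

    def remove(self) -> None:
        # Removing the handles restores the module's original behavior.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
capture = LatentCapture()
capture.register("encoder", encoder)

out = encoder(torch.randn(2, 8))      # an ordinary forward call
latent = capture.latents["encoder"]   # the same tensor the module returned
```

Because the hook only observes the output and returns `None`, the forward pass result is unchanged; a later `engine.step()`-style call can then read the captured latents.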
High-Performance Modular Add-ons
For advanced cognitive systems or performance-critical setups, the package includes highly optimized, zero-copy, fully vectorized components.
1. Advanced Attention Selectors (selectors.py)
Replace the default selector with one of three subclasses of `BaseSelector`:
- `CosineSimilaritySelector`: uses vectorized, matrix-based `cosine_similarity` to match incoming states against top-down goals.
- `VectorizedCrossAttentionSelector`: calls PyTorch's native `scaled_dot_product_attention` with the value tensor set to an identity matrix, so the output is the attention-weight matrix itself. The selector stays fully differentiable, and PyTorch can dispatch to fused kernels such as FlashAttention where the backend supports them.
- `EfficientGumbelSoftmaxSelector`: hard, winner-take-all routing via a single `gumbel_softmax` pass; with hard sampling, gradients flow through a straight-through estimator.
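The three selection strategies can be sketched with standard `torch.nn.functional` calls. These functions are illustrative stand-ins, not the library's selector classes; the shapes (`states` as `[B, N, D]`, a top-down `goal` as `[B, D]`) and the function names are assumptions.

```python
import torch
import torch.nn.functional as F


def cosine_select(states, goal):
    """states: [B, N, D] module latents; goal: [B, D] top-down query.
    Returns cosine scores [B, N] and the best-matching state [B, D]."""
    scores = F.cosine_similarity(states, goal.unsqueeze(1), dim=-1)  # broadcasts over N
    winner = scores.argmax(dim=-1)
    return scores, states[torch.arange(states.size(0)), winner]


def attention_weights_sdpa(states, goal):
    """Identity-value trick: softmax(q k^T / sqrt(D)) @ I equals the attention weights."""
    B, N, _ = states.shape
    query = goal.unsqueeze(1)                                               # [B, 1, D]
    value = torch.eye(N, dtype=states.dtype, device=states.device).repeat(B, 1, 1)  # [B, N, N]
    return F.scaled_dot_product_attention(query, states, value).squeeze(1)  # [B, N]


def gumbel_hard_select(states, logits, tau=1.0):
    """Hard winner-take-all: one-hot in the forward pass, straight-through gradients."""
    one_hot = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)  # [B, N]
    return torch.einsum("bn,bnd->bd", one_hot, states)               # [B, D]
```

The identity value matrix is what turns the attention output into the selection weights: each output row is a convex combination of the rows of `I`, which is just the weight vector.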
2. Low-Overhead Salience Metrics (salience.py)
Evaluate module activations to compute confidence and decide workspace ignition:
- `MagnitudeSalience`: computes the L2 norm of each state in a single vectorized pass using `torch.linalg.vector_norm`.
- `EntropySalience`: computes Shannon entropy normalized to a $[0, 1]$ confidence score, penalizing noisy, unconfident representations.
- `TemporalSurpriseSalience`: a stateful cache that detaches gradients across iterations and uses cosine distance to measure step-to-step temporal shifts.
- `global_pool_latent`: a dimension-agnostic helper that maps spatial or temporal latent tensors (e.g. `[B, T, D]` or `[B, C, H, W]`) to `[B, D]` vectors without Python loops.
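A minimal sketch of the four salience ideas follows, assuming mean pooling over time/space and a confidence of $1 - H(p)/\log D$ with $p = \mathrm{softmax}(z)$. The exact formulas the library uses are not stated above, so the function and class names here (`pool_latent`, `magnitude_salience`, `entropy_confidence`, `TemporalSurprise`) are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def pool_latent(x):
    """[B, D] unchanged; [B, T, D] -> mean over T; [B, C, H, W, ...] -> mean over spatial dims."""
    if x.dim() == 2:
        return x
    if x.dim() == 3:
        return x.mean(dim=1)
    return x.flatten(2).mean(dim=-1)


def magnitude_salience(x):
    return torch.linalg.vector_norm(pool_latent(x), ord=2, dim=-1)  # [B]


def entropy_confidence(x):
    p = F.softmax(pool_latent(x), dim=-1)
    entropy = -(p * torch.log(p.clamp_min(1e-12))).sum(dim=-1)
    return 1.0 - entropy / math.log(p.size(-1))  # 1 = sharply peaked, 0 = uniform


class TemporalSurprise:
    """Stateful cosine-distance surprise between consecutive steps."""

    def __init__(self):
        self.prev = None

    def __call__(self, x):
        z = pool_latent(x)
        if self.prev is None or self.prev.shape != z.shape:
            surprise = torch.zeros(z.size(0), dtype=z.dtype, device=z.device)
        else:
            surprise = 1.0 - F.cosine_similarity(z, self.prev, dim=-1)  # in [0, 2]
        self.prev = z.detach()  # cache without keeping the previous step's graph alive
        return surprise
```

Detaching the cached state means each step's surprise backpropagates only into the current step's computation, which keeps memory flat across long loops.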
3. Short-Term Decay Working Memory (memory.py)
`DecayWorkingMemory`: Exponentially decays inactive workspace slots in place (`workspace_state.mul_(decay_rate)`) and returns a blended context vector that combines the newly active winner with the decaying trace of past states.
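One way to realize an in-place decaying trace is sketched below. It is not the library's `DecayWorkingMemory`: the class name, the fixed `blend` weight, and updating the trace as an exponential moving average of past winners are assumptions chosen to make the behavior easy to follow.

```python
import torch
import torch.nn as nn


class DecayTraceMemory(nn.Module):
    """Hypothetical decaying working-memory trace (illustrative sketch)."""

    def __init__(self, decay_rate: float = 0.9, blend: float = 0.5):
        super().__init__()
        self.decay_rate = decay_rate
        self.blend = blend
        self.register_buffer("trace", None, persistent=False)

    def forward(self, winner: torch.Tensor) -> torch.Tensor:
        if self.trace is None or self.trace.shape != winner.shape:
            self.trace = torch.zeros_like(winner)  # lazily sized to the batch
        with torch.no_grad():
            self.trace.mul_(self.decay_rate)  # in-place exponential decay of the past
        # Blend the current winner with the decayed history.
        context = self.blend * winner + (1.0 - self.blend) * self.trace
        with torch.no_grad():
            # Write the new winner into the trace (exponential moving average).
            self.trace.add_(winner, alpha=1.0 - self.decay_rate)
        return context


mem = DecayTraceMemory(decay_rate=0.5, blend=0.5)
c1 = mem(torch.ones(1, 2))   # empty history: context = 0.5 * winner
c2 = mem(torch.zeros(1, 2))  # no new winner: context comes only from the decayed trace
```

Mutating the buffer under `torch.no_grad()` avoids allocating a new history tensor every step and keeps autograd from tracking the trace.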
4. Parallel Output Router (routing.py)
`CognitiveOutputRouter`: Maps the unified workspace representation into dedicated output heads simultaneously with a single parallel linear projection, returning each head's output as a zero-copy slice (a view) of that one result.
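The fused-projection idea can be sketched with one `nn.Linear` whose output width is the sum of all head sizes, followed by `torch.split`, which returns views rather than copies. `FusedHeadRouter` and the example head names are hypothetical, not the library's `CognitiveOutputRouter`.

```python
import torch
import torch.nn as nn


class FusedHeadRouter(nn.Module):
    """Hypothetical sketch: one Linear computes every head at once;
    torch.split returns per-head views of that single result."""

    def __init__(self, latent_dim: int, head_dims: dict):
        super().__init__()
        self.names = list(head_dims)
        self.sizes = [head_dims[n] for n in self.names]
        self.proj = nn.Linear(latent_dim, sum(self.sizes))  # one matmul for all heads

    def forward(self, workspace: torch.Tensor) -> dict:
        fused = self.proj(workspace)                       # [B, sum(head_dims)]
        parts = torch.split(fused, self.sizes, dim=-1)     # views into `fused`, no copies
        return dict(zip(self.names, parts))


router = FusedHeadRouter(8, {"policy": 4, "value": 1})
x = torch.randn(3, 8)
out = router(x)
```

A single wide matmul is usually cheaper than several narrow ones because it launches one kernel and reads the workspace vector once.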
5. Active Dendritic Gating (gwt/dendrite.py) [Phase v0.2]
Biologically inspired active dendritic pre-processors that dynamically modulate local feedforward pathways using GWT context:
- `ActiveDendriteGate`: a fully vectorized module that projects context `[B, context_dim]` onto dendritic branches. It supports two modes:
  - `"modulatory-gain"`: smooth sigmoidal contextual scaling.
  - `"nmda-threshold"`: sharp, thresholded biological spiking. A branch spikes if its local depolarization $\ge \text{threshold}$ and is zeroed out otherwise; a Straight-Through Estimator (STE) keeps training gradients flowing.
- `DendriticModuleAdapter`: a subclass of `ModuleAdapter` that detects model output dimensions statically (or dynamically on the first forward pass) and appends dendritic context gating to the module's execution hook pipeline.
- `get_dendritic_status`: a telemetry helper that scans modules recursively and reports the percentage of active vs. muted pathways across all dendritic gates.
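Both gating modes and a muted-pathway statistic can be sketched in a few lines. This is not the library's `ActiveDendriteGate` or `get_dendritic_status`: the class and helper names, the one-branch-per-feature projection, and using a sigmoid as the STE's backward surrogate are assumptions.

```python
import torch
import torch.nn as nn


class DendriteGateSketch(nn.Module):
    """Hypothetical gate: context [B, context_dim] modulates features [B, num_features]."""

    def __init__(self, context_dim, num_features, mode="modulatory-gain", threshold=0.0):
        super().__init__()
        self.proj = nn.Linear(context_dim, num_features)  # one dendritic branch per feature
        self.mode = mode
        self.threshold = threshold

    def forward(self, features, context):
        depol = self.proj(context)                   # local "depolarization" per branch
        soft = torch.sigmoid(depol - self.threshold)
        if self.mode == "modulatory-gain":
            gate = soft                              # smooth gain in (0, 1)
        elif self.mode == "nmda-threshold":
            hard = (depol >= self.threshold).to(features.dtype)  # spike (1) or silence (0)
            # Straight-through: exactly `hard` in the forward pass, sigmoid gradient backward.
            gate = hard + (soft - soft.detach())
        else:
            raise ValueError(f"unknown mode: {self.mode}")
        return features * gate, gate


def muted_fraction(gate):
    """Share of branches that are fully silenced (gate == 0)."""
    return (gate == 0).float().mean().item()
```

Writing the STE as `hard + (soft - soft.detach())` keeps the forward value exactly 0 or 1, since the parenthesized term is exactly zero, while the backward pass sees the sigmoid's gradient.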
For a full demo showing how to attach and run these components in an active training loop, see example_addons.py.
License
This project is licensed under the MIT License; see the `LICENSE` file for details.
Download files
Source Distribution: cognitive_aug-0.2.2.tar.gz
Built Distribution: cognitive_aug-0.2.2-py3-none-any.whl
File details
Details for the file cognitive_aug-0.2.2.tar.gz.
File metadata
- Download URL: cognitive_aug-0.2.2.tar.gz
- Upload date:
- Size: 2.1 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `923b3b30cfe9c8eb89c01e13f6acfd8cf557de7a5428263250e3440c8a0c1089` |
| MD5 | `88f7ed05e1c8347e9613c09ec3905b15` |
| BLAKE2b-256 | `f12c145482f861f22de050539272cf279a9af401e7638cc90b87eadff3c5e83c` |
File details
Details for the file cognitive_aug-0.2.2-py3-none-any.whl.
File metadata
- Download URL: cognitive_aug-0.2.2-py3-none-any.whl
- Upload date:
- Size: 25.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `7665ba747d699a7ec7d27a01e7e4aa5ed4b3a4a1746445c28d58a2fee7f5e67b` |
| MD5 | `b8fa5e2157ab01af998c610d64eb927a` |
| BLAKE2b-256 | `5dc0c90eaecac8573c2e4bd3d458b6a614d69848974ced47dea0716843291759` |