Neural-Matter Network (NMN) - Advanced neural network layers with attention mechanisms

โš›๏ธ NMN โ€” Neural Matter Networks

Not the neurons we want, but the neurons we need

Activation-free neural layers that learn non-linearity through geometric operations



🎯 TL;DR

NMN replaces traditional Linear + ReLU with a single geometric operation that learns non-linearity without activation functions:

# Traditional approach
y = relu(linear(x))  # dot product → activation

# NMN approach  
y = yat(x)  # geometric operation with built-in non-linearity

The Yat-Product (ⵟ) balances similarity and distance to create inherently non-linear transformations, so no activation functions are needed.


✨ Key Features

| Feature | Description |
|---------|-------------|
| 🔥 Activation-Free | Learn complex non-linear relationships without ReLU, sigmoid, or tanh |
| 🌐 Multi-Framework | PyTorch, TensorFlow, Keras, Flax (Linen & NNX) |
| 🧮 Geometric Foundation | Based on a distance-similarity tradeoff, not just correlations |
| ✅ Cross-Framework Consistency | Verified numerical equivalence across all frameworks |
| 🧠 Complete Layer Suite | Dense, Conv1D/2D/3D, ConvTranspose, Attention, RNN cells |
| ⚡ Production Ready | Comprehensive tests, CI/CD, high code coverage |

๐Ÿ“ The Mathematics

Yat-Product (ⵟ)

The core operation that powers NMN:

$$ ⵟ(\mathbf{w}, \mathbf{x}) = \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} $$
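Concretely, the Yat-Product can be sketched in a few lines of plain Python (a pedagogical sketch, not the library's vectorized implementation):

```python
def yat_product(w, x, epsilon=1e-5):
    """Yat-Product: squared dot product divided by squared Euclidean distance."""
    dot = sum(wi * xi for wi, xi in zip(w, x))
    dist_sq = sum((wi - xi) ** 2 for wi, xi in zip(w, x))
    return dot ** 2 / (dist_sq + epsilon)

# Aligned AND close: the distance term vanishes, so the output is large
print(yat_product([1.0, 2.0], [1.0, 2.0]))    # 25 / epsilon -> very large
# Aligned but distant: the denominator damps the response
print(yat_product([1.0, 2.0], [10.0, 20.0]))  # 2500 / (405 + epsilon) ~ 6.17
```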

๐Ÿ” Geometric Interpretation (click to expand)

Rewriting in terms of norms and angles:

$$ ⵟ(\mathbf{w}, \mathbf{x}) = \frac{\|\mathbf{w}\|^2 \|\mathbf{x}\|^2 \cos^2\theta}{\|\mathbf{w}\|^2 - 2\langle\mathbf{w}, \mathbf{x}\rangle + \|\mathbf{x}\|^2 + \epsilon} $$

Output is maximized when:

  • ✅ Vectors are aligned (small θ → large cos²θ)
  • ✅ Vectors are close (small Euclidean distance)
  • ✅ Vectors have large magnitude (amplifies the signal)

This creates a fundamentally different learning dynamic:

| Traditional Neuron | Yat Neuron |
|--------------------|------------|
| Measures correlation only | Balances similarity AND proximity |
| Requires an activation for non-linearity | Non-linearity is intrinsic |
| Can fire for distant but aligned vectors | Penalizes distance between w and x |

Yat-Convolution (ⵟ*)

The same principle applied to local patches:

$$ ⵟ^*(\mathbf{W}, \mathbf{X}) = \frac{\left(\sum_{i,j} w_{ij} \cdot x_{ij}\right)^2}{\sum_{i,j}(w_{ij} - x_{ij})^2 + \epsilon} $$

Where W is the kernel and X is the input patch.
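The per-patch computation can be sketched directly; this illustrative helper applies the Yat-Product to one flattened 2-D patch, whereas the packaged layers implement it as a vectorized convolution:

```python
def yat_conv_patch(kernel, patch, epsilon=1e-5):
    """Yat-Convolution response for a single 2-D kernel/patch pair."""
    pairs = [(w, x) for krow, prow in zip(kernel, patch)
                    for w, x in zip(krow, prow)]
    num = sum(w * x for w, x in pairs) ** 2          # squared "dot product"
    den = sum((w - x) ** 2 for w, x in pairs) + epsilon  # squared distance
    return num / den

kernel = [[1.0, 0.0],
          [0.0, 1.0]]
patch  = [[1.0, 0.0],
          [0.0, 1.0]]  # identical to the kernel -> maximal response
print(yat_conv_patch(kernel, patch))  # 4 / epsilon
```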


🚀 Quick Start

Installation

pip install nmn

# Framework-specific installations
pip install "nmn[torch]"    # PyTorch
pip install "nmn[keras]"    # Keras/TensorFlow  
pip install "nmn[nnx]"      # Flax NNX (JAX)
pip install "nmn[all]"      # Everything

Basic Usage

PyTorch

import torch
from nmn.torch.nmn import YatNMN

# Replace nn.Linear + activation
layer = YatNMN(
    in_features=128,
    out_features=64,
    epsilon=1e-5
)

x = torch.randn(32, 128)
y = layer(x)  # (32, 64) - non-linear output!

Keras

import keras
from nmn.keras.nmn import YatNMN

# Drop-in replacement for Dense
layer = YatNMN(
    features=64,
    epsilon=1e-5
)

x = keras.ops.zeros((32, 128))
y = layer(x)  # (32, 64)

Flax NNX

import jax
from flax import nnx
from nmn.nnx.nmn import YatNMN

layer = YatNMN(
    in_features=128,
    out_features=64,
    rngs=nnx.Rngs(0)
)

x = jax.numpy.zeros((32, 128))
y = layer(x)  # (32, 64)

TensorFlow

import tensorflow as tf
from nmn.tf.nmn import YatNMN

layer = YatNMN(features=64)

x = tf.zeros((32, 128))
y = layer(x)  # (32, 64)

📦 Layer Support Matrix

Core Layers

| Layer | PyTorch | TensorFlow | Keras | Flax NNX | Flax Linen |
|-------|---------|------------|-------|----------|------------|
| YatNMN (Dense) | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConv1D | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConv2D | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConv3D | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConvTranspose1D | ✅ | ✅ | ✅ | ✅ | ❌ |
| YatConvTranspose2D | ✅ | ✅ | ✅ | ✅ | ❌ |
| YatConvTranspose3D | ✅ | ✅ | ❌ | ✅ | ❌ |

Advanced Layers (Flax NNX)

| Layer | Status | Description |
|-------|--------|-------------|
| MultiHeadAttention | ✅ | Yat-based attention mechanism |
| YatSimpleCell | ✅ | Simple RNN cell |
| YatLSTMCell | ✅ | LSTM with Yat operations |
| YatGRUCell | ✅ | GRU with Yat operations |
| softermax | ✅ | Generalized softmax: $\frac{x_k^n}{\epsilon + \sum_i x_i^n}$ |
| softer_sigmoid | ✅ | Smooth sigmoid variant |
| soft_tanh | ✅ | Smooth tanh variant |
| DropConnect | ✅ | Weight-level dropout regularization |
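softermax replaces the exponential of a standard softmax with a power. A minimal sketch of the formula above (the parameter names `n` and `epsilon` and their defaults are assumptions for illustration, not the library's signature):

```python
def softermax(xs, n=2, epsilon=1e-5):
    """Power-based softmax variant: x_k**n / (epsilon + sum_i x_i**n)."""
    denom = epsilon + sum(x ** n for x in xs)
    return [x ** n / denom for x in xs]

probs = softermax([1.0, 2.0, 3.0])
print(probs)       # ~ [0.071, 0.286, 0.643]
print(sum(probs))  # slightly below 1.0 because of epsilon
```

Unlike softmax, there is no exponential, so large inputs cannot overflow; epsilon keeps the all-zero input well defined.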

🔬 Cross-Framework Consistency

All implementations are verified to produce numerically equivalent outputs given identical inputs and weights:

| Framework Pair | Max Error | Status |
|----------------|-----------|--------|
| PyTorch ↔ TensorFlow | < 1e-6 | ✅ PASS |
| PyTorch ↔ Keras | < 1e-6 | ✅ PASS |
| PyTorch ↔ Flax NNX | < 1e-6 | ✅ PASS |
| PyTorch ↔ Flax Linen | < 1e-6 | ✅ PASS |
| TensorFlow ↔ Keras | < 1e-7 | ✅ PASS |
| Flax NNX ↔ Flax Linen | < 1e-7 | ✅ PASS |

This demonstrates the robustness of the geometric YAT formulation across different numerical backends.


📚 Examples

See EXAMPLES.md for comprehensive usage guides including:

  • Framework-specific quick starts (PyTorch, Keras, TensorFlow, Flax)
  • Architecture examples (CNN, Transformer, RNN)
  • Advanced features (DropConnect, custom squashers, attention)

Quick run:

python examples/torch/yat_cifar10.py      # PyTorch CIFAR-10
python examples/keras/language_imdb.py    # Keras sentiment
python examples/nnx/language/mingpt.py    # JAX GPT

🧪 Testing

Comprehensive test suite with cross-framework validation:

# Install test dependencies
pip install "nmn[test]"

# Run all tests
pytest tests/ -v

# Run specific framework
pytest tests/test_torch/ -v
pytest tests/test_keras/ -v
pytest tests/test_nnx/ -v

# Run cross-framework consistency tests
pytest tests/integration/test_cross_framework_consistency.py -v

# With coverage
pytest tests/ --cov=nmn --cov-report=html

Test Structure

tests/
├── test_torch/          # PyTorch layer tests + math validation
├── test_keras/          # Keras layer tests
├── test_tf/             # TensorFlow layer tests
├── test_nnx/            # Flax NNX tests (attention, RNN, etc.)
├── test_linen/          # Flax Linen tests
└── integration/
    ├── test_cross_framework_consistency.py  # Numerical equivalence
    └── test_compatibility.py                # API compatibility

📚 Theoretical Foundation

Based on the research papers:

Deep Learning 2.0: Artificial Neurons that Matter - Reject Correlation, Embrace Orthogonality

Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks

Why Yat-Product?

Traditional neurons compute: $y = \sigma(\mathbf{w}^\top \mathbf{x} + b)$

This has limitations:

  • Correlation-based: Only measures alignment, ignores proximity
  • Requires activation: Non-linearity is external
  • Spurious activations: Can fire strongly for distant but aligned vectors

The Yat-Product addresses these by combining:

  1. Squared dot product (similarity) in the numerator
  2. Squared distance (proximity) in the denominator
  3. Epsilon for numerical stability

The result is a neuron that responds geometrically: it activates when inputs are both similar AND close to the weights.
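To make the spurious-activation point concrete, compare a plain dot-product neuron with a Yat neuron on an input that is aligned with the weights but far from them (a pure-Python sketch, illustrative only):

```python
def linear_neuron(w, x):
    """Correlation only: the raw dot product."""
    return sum(wi * xi for wi, xi in zip(w, x))

def yat_neuron(w, x, epsilon=1e-5):
    """Similarity AND proximity: the Yat-Product."""
    dot = sum(wi * xi for wi, xi in zip(w, x))
    dist_sq = sum((wi - xi) ** 2 for wi, xi in zip(w, x))
    return dot ** 2 / (dist_sq + epsilon)

w    = [1.0, 1.0]
near = [1.0, 1.0]      # aligned AND close to w
far  = [100.0, 100.0]  # aligned with w, but distant

# The dot product fires 100x more strongly for the distant vector...
print(linear_neuron(w, near), linear_neuron(w, far))  # 2.0 200.0
# ...while the Yat neuron penalizes the distance
print(yat_neuron(w, near) > yat_neuron(w, far))       # True
```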


๐Ÿค Contributing

We welcome contributions! See CONTRIBUTING.md for guidelines.

# Development setup
git clone https://github.com/mlnomadpy/nmn.git
cd nmn
pip install -e ".[dev,test]"

# Run tests
pytest tests/ -v

# Format code
black src/ tests/
isort src/ tests/

Areas for contribution:

  • ๐Ÿ› Bug fixes (open issues)
  • โœจ New layer types (normalization, graph, etc.)
  • ๐Ÿ“š Documentation and tutorials
  • โšก Performance optimizations
  • ๐ŸŽจ Example applications

📖 API Reference

Core Parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| in_features | int | Input dimension (Dense) or input channels (Conv) |
| out_features | int | Output dimension or number of filters |
| kernel_size | int or tuple | Convolution kernel size |
| epsilon | float | Numerical stability constant (default: 1e-5) |
| use_bias | bool | Include a bias term (default: True) |
| use_alpha | bool | Learnable output scaling (default: True) |

Quick Imports

# PyTorch
from nmn.torch.nmn import YatNMN
from nmn.torch.layers import YatConv2d, YatConvTranspose2d

# Keras / TensorFlow
from nmn.keras.nmn import YatNMN
from nmn.keras.conv import YatConv2D

# Flax NNX (most complete)
from nmn.nnx.nmn import YatNMN
from nmn.nnx.yatconv import YatConv
from nmn.nnx.yatattention import MultiHeadAttention
from nmn.nnx.rnn import YatLSTMCell

📋 Full import reference → EXAMPLES.md


📄 Citation

If you use NMN in your research, please cite:

@software{nmn2024,
  author = {Bouhsine, Taha},
  title = {NMN: Neural Matter Networks},
  year = {2024},
  url = {https://github.com/mlnomadpy/nmn}
}

@article{bouhsine2024dl2,
  author = {Taha Bouhsine},
  title = {Deep Learning 2.0: Artificial Neurons that Matter},
  year = {2024}
}

@article{bouhsine2025dl21,
  author = {Taha Bouhsine},
  title = {Deep Learning 2.1: Mind and Cosmos},
  year = {2025}
}

@article{bouhsine2025nomoredelulu,
  author = {Taha Bouhsine},
  title = {No More DeLuLu: A Kernel-Based Activation-Free Neural Networks},
  year = {2025}
}

📬 Support


📜 License

AGPL-3.0: free for personal, academic, and commercial use with attribution.

If you modify and deploy on a network, you must share the source code.

For alternative licensing, contact us.


Built with ❤️ by azetta.ai
