Skip to main content

A library for building dynamic and self-pruning MLPs with Mixture of Experts and Layer Skipping.

Project description

Sparrow 🐦: A Dynamic MLP Architecture

PyPI version License: MIT

English | فارسی


English

Sparrow is a PyTorch library for building Dynamic Multi-Layer Perceptrons (MLPs). This architecture learns its own optimal structure for any given task by combining two powerful concepts:

  1. Dynamic Depth: A global router learns to activate or bypass entire hidden layers based on the input, finding the shortest computational path needed.
  2. Mixture of Experts (MoE): Each hidden layer is composed of several smaller "expert" networks. A local router within each layer selects the best expert for the current data, enabling neuron specialization and efficient computation.

This results in a highly efficient and adaptive neural network that prunes itself during training.

Key Features

  • Dynamic Depth: Automatically learns which layers to skip.
  • Mixture of Experts Layers: Activates only a subset of neurons in each layer.
  • Self-Pruning: Learns to become more computationally efficient as it masters a task.
  • Simple API: Build complex dynamic models with just a few lines of code.

Installation

We recommend creating a new virtual environment.

pip install sparrow-mlp

Quickstart

Here is a complete example of training a DynamicMLP on the Iris dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sparrow import DynamicMLP

# 1. Load and prepare data
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.LongTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.LongTensor(y_test)

# 2. Define the DynamicMLP model
model = DynamicMLP(
    input_size=4,
    output_size=3,
    hidden_dim=32,
    num_hidden_layers=2,
    num_experts=4,
    expert_hidden_size=16
)

optimizer = optim.Adam(model.parameters(), lr=0.005)
classification_criterion = nn.CrossEntropyLoss()
epochs = 300
LAYER_SPARSITY_LAMBDA = 0.01

# 3. Train the model
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    classification_loss = classification_criterion(outputs, y_train_tensor)
    # Add sparsity loss to encourage layer skipping
    layer_sparsity_loss = LAYER_SPARSITY_LAMBDA * model.layer_gates_values.sum()
    total_loss = classification_loss + layer_sparsity_loss
    total_loss.backward()
    optimizer.step()

# 4. Evaluate the model
model.eval()
with torch.no_grad():
    test_outputs = model(X_test_tensor)
    _, predicted_labels = torch.max(test_outputs, 1)
    accuracy = accuracy_score(y_test_tensor.numpy(), predicted_labels.numpy())
    print(f'Final Accuracy on Iris Test Data: {accuracy * 100:.2f}%')

فارسی

Sparrow 🐦 یک کتابخانه پایتورچ برای ساخت پرسپترون‌های چندلایه دینامیک (MLP) است. این معماری ساختار بهینه خود را برای هر وظیفه با ترکیب دو مفهوم قدرتمند یاد می‌گیرد:

۱. عمق دینامیک: یک مسیریاب سراسری یاد می‌گیرد که بر اساس ورودی، کل لایه‌های پنهان را فعال یا از آنها عبور کند و کوتاه‌ترین مسیر محاسباتی مورد نیاز را پیدا کند. ۲. ترکیبی از متخصصان (MoE): هر لایه پنهان از چندین شبکه "متخصص" کوچکتر تشکیل شده است. یک مسیریاب محلی در هر لایه بهترین متخصص را برای داده فعلی انتخاب می‌کند که منجر به تخصصی شدن نورون‌ها و محاسبات بهینه می‌شود.

نتیجه این ترکیب، یک شبکه عصبی بسیار کارآمد و تطبیق‌پذیر است که در طول آموزش خود را هرس می‌کند.

قابلیت‌های کلیدی

  • عمق دینامیک: به طور خودکار یاد می‌گیرد کدام لایه‌ها را رد کند.
  • لایه‌های MoE: تنها زیرمجموعه‌ای از نورون‌ها را در هر لایه فعال می‌کند.
  • هرس خودکار: با مسلط شدن بر یک وظیفه، یاد می‌گیرد که از نظر محاسباتی بهینه‌تر شود.
  • API ساده: ساخت مدل‌های دینامیک پیچیده تنها با چند خط کد.

نصب

توصیه می‌شود یک محیط مجازی جدید ایجاد کنید.

pip install sparrow-mlp

شروع سریع

در بالا یک مثال کامل از آموزش DynamicMLP روی دیتاست گل زنبق آمده است.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sparrow_mlp-3.0.0.tar.gz (5.8 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

sparrow_mlp-3.0.0-py3-none-any.whl (6.4 kB view details)

Uploaded Python 3

File details

Details for the file sparrow_mlp-3.0.0.tar.gz.

File metadata

  • Download URL: sparrow_mlp-3.0.0.tar.gz
  • Upload date:
  • Size: 5.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for sparrow_mlp-3.0.0.tar.gz
Algorithm Hash digest
SHA256 c05d71d3ab76cc6fa10d107e02428ad98a00bddd17ea0c6c18b25143dc9bd1e5
MD5 0c1a17f2133c05f5fffe96755fc4db2e
BLAKE2b-256 3f84a7f8ba9192e24b921bdc10eea37cb5c6ab11e12a2146ae78b6da5ec8ec3b

See more details on using hashes here.

File details

Details for the file sparrow_mlp-3.0.0-py3-none-any.whl.

File metadata

  • Download URL: sparrow_mlp-3.0.0-py3-none-any.whl
  • Upload date:
  • Size: 6.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for sparrow_mlp-3.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 66fde2a0e4484cea875f32bd368b83df7e25945f9032ab726df8afa9b65b9f22
MD5 a94aaf742f5725947ca9393b22bc684e
BLAKE2b-256 578986dc148c9f4139c1b0f02a23b085f93ab5837b8c939acba87954a6551e50

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page