Skip to main content

A library for building dynamic and self-pruning MLPs with Mixture of Experts and Layer Skipping.

Project description

Sparrow 🐦: A Dynamic MLP Architecture

PyPI version License: MIT

English | فارسی


English

Sparrow is a PyTorch library for building Dynamic Multi-Layer Perceptrons (MLPs). This architecture learns its own optimal structure for any given task by combining two powerful concepts:

  1. Dynamic Depth: A global router learns to activate or bypass entire hidden layers based on the input, finding the shortest computational path needed.
  2. Mixture of Experts (MoE): Each hidden layer is composed of several smaller "expert" networks. A local router within each layer selects the best expert for the current data, enabling neuron specialization and efficient computation.

This results in a highly efficient and adaptive neural network that prunes itself during training.

Key Features

  • Dynamic Depth: Automatically learns which layers to skip.
  • Mixture of Experts Layers: Activates only a subset of neurons in each layer.
  • Self-Pruning: Learns to become more computationally efficient as it masters a task.
  • Simple API: Build complex dynamic models with just a few lines of code.

Installation

We recommend creating a new virtual environment.

pip install sparrow-mlp

Quickstart

Here is a complete example of training a DynamicMLP on the Iris dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sparrow import DynamicMLP

# 1. Load and prepare data
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.LongTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.LongTensor(y_test)

# 2. Define the DynamicMLP model
model = DynamicMLP(
    input_size=4,
    output_size=3,
    hidden_dim=32,
    num_hidden_layers=2,
    num_experts=4,
    expert_hidden_size=16
)

optimizer = optim.Adam(model.parameters(), lr=0.005)
classification_criterion = nn.CrossEntropyLoss()
epochs = 300
LAYER_SPARSITY_LAMBDA = 0.01

# 3. Train the model
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    classification_loss = classification_criterion(outputs, y_train_tensor)
    # Add sparsity loss to encourage layer skipping
    layer_sparsity_loss = LAYER_SPARSITY_LAMBDA * model.layer_gates_values.sum()
    total_loss = classification_loss + layer_sparsity_loss
    total_loss.backward()
    optimizer.step()

# 4. Evaluate the model
model.eval()
with torch.no_grad():
    test_outputs = model(X_test_tensor)
    _, predicted_labels = torch.max(test_outputs, 1)
    accuracy = accuracy_score(y_test_tensor.numpy(), predicted_labels.numpy())
    print(f'Final Accuracy on Iris Test Data: {accuracy * 100:.2f}%')

فارسی

Sparrow 🐦 یک کتابخانه پایتورچ برای ساخت پرسپترون‌های چندلایه دینامیک (MLP) است. این معماری ساختار بهینه خود را برای هر وظیفه با ترکیب دو مفهوم قدرتمند یاد می‌گیرد:

۱. عمق دینامیک: یک مسیریاب سراسری یاد می‌گیرد که بر اساس ورودی، کل لایه‌های پنهان را فعال یا از آنها عبور کند و کوتاه‌ترین مسیر محاسباتی مورد نیاز را پیدا کند. ۲. ترکیبی از متخصصان (MoE): هر لایه پنهان از چندین شبکه "متخصص" کوچکتر تشکیل شده است. یک مسیریاب محلی در هر لایه بهترین متخصص را برای داده فعلی انتخاب می‌کند که منجر به تخصصی شدن نورون‌ها و محاسبات بهینه می‌شود.

نتیجه این ترکیب، یک شبکه عصبی بسیار کارآمد و تطبیق‌پذیر است که در طول آموزش خود را هرس می‌کند.

قابلیت‌های کلیدی

  • عمق دینامیک: به طور خودکار یاد می‌گیرد کدام لایه‌ها را رد کند.
  • لایه‌های MoE: تنها زیرمجموعه‌ای از نورون‌ها را در هر لایه فعال می‌کند.
  • هرس خودکار: با مسلط شدن بر یک وظیفه، یاد می‌گیرد که از نظر محاسباتی بهینه‌تر شود.
  • API ساده: ساخت مدل‌های دینامیک پیچیده تنها با چند خط کد.

نصب

توصیه می‌شود یک محیط مجازی جدید ایجاد کنید.

pip install sparrow-mlp

شروع سریع

در بالا یک مثال کامل از آموزش DynamicMLP روی دیتاست گل زنبق آمده است.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sparrow_mlp-3.0.2.tar.gz (6.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

sparrow_mlp-3.0.2-py3-none-any.whl (7.7 kB view details)

Uploaded Python 3

File details

Details for the file sparrow_mlp-3.0.2.tar.gz.

File metadata

  • Download URL: sparrow_mlp-3.0.2.tar.gz
  • Upload date:
  • Size: 6.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for sparrow_mlp-3.0.2.tar.gz
Algorithm Hash digest
SHA256 9875ab756c321beac0378eeaf5651ea124bc5606725690af0dc14374fc731b26
MD5 b9a33d0e06d55528bf242fa7d2383e22
BLAKE2b-256 ea7d90d86131b035e72ef584756626462e2b3d0c15cc0dd5875e67daa7d581fd

See more details on using hashes here.

File details

Details for the file sparrow_mlp-3.0.2-py3-none-any.whl.

File metadata

  • Download URL: sparrow_mlp-3.0.2-py3-none-any.whl
  • Upload date:
  • Size: 7.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for sparrow_mlp-3.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 39e6421eef2bbd6ff5b233154e830627fcdb4998f6649811711c5e13ed7d50c6
MD5 44a3c1e9dc980836d8bf65f53b14d863
BLAKE2b-256 443b092f0e2c76797bdef808b3417eb9ac439eb8b431471c65e9edf9ee2a0e35

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page