A library for building dynamic and self-pruning MLPs with Mixture of Experts and Layer Skipping.
Project description
Sparrow 🐦: A Dynamic MLP Architecture
English
Sparrow is a PyTorch library for building Dynamic Multi-Layer Perceptrons (MLPs). This architecture learns its own optimal structure for any given task by combining two powerful concepts:
- Dynamic Depth: A global router learns to activate or bypass entire hidden layers based on the input, finding the shortest computational path needed.
- Mixture of Experts (MoE): Each hidden layer is composed of several smaller "expert" networks. A local router within each layer selects the best expert for the current data, enabling neuron specialization and efficient computation.
This results in a highly efficient and adaptive neural network that prunes itself during training.
Key Features
- Dynamic Depth: Automatically learns which layers to skip.
- Mixture of Experts Layers: Activates only a subset of neurons in each layer.
- Self-Pruning: Learns to become more computationally efficient as it masters a task.
- Simple API: Build complex dynamic models with just a few lines of code.
Installation
We recommend creating a new virtual environment.
pip install sparrow-mlp
Quickstart
Here is a complete example of training a DynamicMLP on the Iris dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sparrow import DynamicMLP
# 1. Load and prepare data
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.LongTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.LongTensor(y_test)
# 2. Define the DynamicMLP model
model = DynamicMLP(
input_size=4,
output_size=3,
hidden_dim=32,
num_hidden_layers=2,
num_experts=4,
expert_hidden_size=16
)
optimizer = optim.Adam(model.parameters(), lr=0.005)
classification_criterion = nn.CrossEntropyLoss()
epochs = 300
LAYER_SPARSITY_LAMBDA = 0.01
# 3. Train the model
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
outputs = model(X_train_tensor)
classification_loss = classification_criterion(outputs, y_train_tensor)
# Add sparsity loss to encourage layer skipping
layer_sparsity_loss = LAYER_SPARSITY_LAMBDA * model.layer_gates_values.sum()
total_loss = classification_loss + layer_sparsity_loss
total_loss.backward()
optimizer.step()
# 4. Evaluate the model
model.eval()
with torch.no_grad():
test_outputs = model(X_test_tensor)
_, predicted_labels = torch.max(test_outputs, 1)
accuracy = accuracy_score(y_test_tensor.numpy(), predicted_labels.numpy())
print(f'Final Accuracy on Iris Test Data: {accuracy * 100:.2f}%')
فارسی
Sparrow 🐦 یک کتابخانه پایتورچ برای ساخت پرسپترونهای چندلایه دینامیک (MLP) است. این معماری ساختار بهینه خود را برای هر وظیفه با ترکیب دو مفهوم قدرتمند یاد میگیرد:
۱. عمق دینامیک: یک مسیریاب سراسری یاد میگیرد که بر اساس ورودی، کل لایههای پنهان را فعال یا از آنها عبور کند و کوتاهترین مسیر محاسباتی مورد نیاز را پیدا کند. ۲. ترکیبی از متخصصان (MoE): هر لایه پنهان از چندین شبکه "متخصص" کوچکتر تشکیل شده است. یک مسیریاب محلی در هر لایه بهترین متخصص را برای داده فعلی انتخاب میکند که منجر به تخصصی شدن نورونها و محاسبات بهینه میشود.
نتیجه این ترکیب، یک شبکه عصبی بسیار کارآمد و تطبیقپذیر است که در طول آموزش خود را هرس میکند.
قابلیتهای کلیدی
- عمق دینامیک: به طور خودکار یاد میگیرد کدام لایهها را رد کند.
- لایههای MoE: تنها زیرمجموعهای از نورونها را در هر لایه فعال میکند.
- هرس خودکار: با مسلط شدن بر یک وظیفه، یاد میگیرد که از نظر محاسباتی بهینهتر شود.
- API ساده: ساخت مدلهای دینامیک پیچیده تنها با چند خط کد.
نصب
توصیه میشود یک محیط مجازی جدید ایجاد کنید.
pip install sparrow-mlp
شروع سریع
در بالا یک مثال کامل از آموزش DynamicMLP روی دیتاست گل زنبق آمده است.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file sparrow_mlp-3.0.2.tar.gz.
File metadata
- Download URL: sparrow_mlp-3.0.2.tar.gz
- Upload date:
- Size: 6.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
9875ab756c321beac0378eeaf5651ea124bc5606725690af0dc14374fc731b26
|
|
| MD5 |
b9a33d0e06d55528bf242fa7d2383e22
|
|
| BLAKE2b-256 |
ea7d90d86131b035e72ef584756626462e2b3d0c15cc0dd5875e67daa7d581fd
|
File details
Details for the file sparrow_mlp-3.0.2-py3-none-any.whl.
File metadata
- Download URL: sparrow_mlp-3.0.2-py3-none-any.whl
- Upload date:
- Size: 7.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
39e6421eef2bbd6ff5b233154e830627fcdb4998f6649811711c5e13ed7d50c6
|
|
| MD5 |
44a3c1e9dc980836d8bf65f53b14d863
|
|
| BLAKE2b-256 |
443b092f0e2c76797bdef808b3417eb9ac439eb8b431471c65e9edf9ee2a0e35
|