Eggroll Trainer
A library for Evolution Strategy (ES) trainers in PyTorch, including the EGGROLL algorithm.
Documentation
📚 Full Documentation - Complete guide with examples, API reference, and research.
Installation
```shell
# Using pip
pip install eggroll-trainer

# Or using uv
uv add eggroll-trainer

# For examples with plotting:
pip install "eggroll-trainer[examples]"
# or
uv add eggroll-trainer --extra examples
```
For development/contributing, see CONTRIBUTING.md.
What is EGGROLL?
EGGROLL (Evolution Guided General Optimization via Low-rank Learning) is a novel ES algorithm that provides a hundredfold increase in training speed over naïve evolution strategies by using low-rank perturbations instead of full-rank ones.
Key innovation: For matrix parameters W ∈ R^(m×n), EGGROLL samples low-rank matrices A ∈ R^(m×r), B ∈ R^(n×r) where r << min(m,n), forming perturbations A @ B.T. This reduces:
- Memory: O(mn) → O(r(m+n))
- Computation: O(mn) → O(r(m+n))
Yet still achieves high-rank updates through population averaging!
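As a minimal sketch of this idea (illustrative only, not the library's internal API), a rank-r perturbation for a square weight matrix can be formed and its storage cost compared against a full-rank one:

```python
import torch

# Low-rank perturbation for a weight matrix W of shape (m, n),
# as described above: sample factors A (m x r) and B (n x r).
m, n, r = 1024, 1024, 4

A = torch.randn(m, r)
B = torch.randn(n, r)
perturbation = A @ B.T  # full (m, n) perturbation, but rank <= r

# Only the factors need to be stored per population member.
full_rank_entries = m * n        # O(mn)     -> 1,048,576 entries
low_rank_entries = r * (m + n)   # O(r(m+n)) ->     8,192 entries

print(perturbation.shape)                      # torch.Size([1024, 1024])
print(full_rank_entries // low_rank_entries)   # 128x fewer stored entries
```

Even though each individual perturbation is rank-r, averaging many of them across the population yields a high-rank parameter update.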
Based on: Evolution Strategies at the Hyperscale
Usage
EGGROLL Trainer (Recommended)
```python
import torch
import torch.nn as nn
from eggroll_trainer import EGGROLLTrainer

# Define a model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)  # Matrix: uses LoRA updates
        self.fc2 = nn.Linear(20, 1)   # Matrix: uses LoRA updates

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Define a fitness function (higher is better)
def fitness_fn(model):
    # Your evaluation logic
    return torch.randn(1).item()

# Create the EGGROLL trainer
model = SimpleModel()
trainer = EGGROLLTrainer(
    model.parameters(),
    model=model,
    fitness_fn=fitness_fn,
    population_size=256,   # Large populations are efficient!
    learning_rate=0.01,
    sigma=0.1,
    rank=1,                # Rank of the low-rank perturbations (1 is often sufficient)
    noise_reuse=0,         # 0 = no reuse, 2 = antithetic sampling
    group_size=0,          # 0 = global normalization
    freeze_nonlora=False,  # If True, only update matrix params
    seed=42,
)

# Train
trainer.train(num_generations=100)
```
Base ES Trainer
For custom ES algorithms, subclass ESTrainer:
```python
import torch
from eggroll_trainer import ESTrainer

class MyESTrainer(ESTrainer):
    def sample_perturbations(self, population_size):
        # One flat Gaussian perturbation per population member
        param_dim = self.current_params.shape[0]
        return torch.randn(population_size, param_dim, device=self.device)

    def compute_update(self, perturbations, fitnesses):
        # Standardize fitnesses (the epsilon guards against zero variance),
        # then take the fitness-weighted mean of the perturbations
        weights = (fitnesses - fitnesses.mean()) / (fitnesses.std() + 1e-8)
        return (weights[:, None] * perturbations).mean(dim=0)
```
Architecture
EGGROLLTrainer
The EGGROLLTrainer implements the actual EGGROLL algorithm:
- Low-rank perturbations for 2D parameters (matrices): Uses A @ B.T where A ∈ R^(m×r), B ∈ R^(n×r)
- Full-rank perturbations for 1D/3D+ parameters (biases, etc.)
- Per-layer updates: Handles each parameter tensor independently
- Fitness normalization: Supports global or group-based normalization
- Noise reuse: Optional antithetic sampling for efficiency
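The antithetic-sampling idea behind the noise-reuse option can be illustrated independently of the library (a standalone sketch, not EGGROLL's internal code): each perturbation is paired with its negation, so only half the noise needs to be drawn and the perturbation mean is zero by construction.

```python
import torch

torch.manual_seed(0)
dim, half_pop = 8, 4

# Draw half the perturbations, then mirror them to fill the population.
eps = torch.randn(half_pop, dim)
perturbations = torch.cat([eps, -eps], dim=0)  # population of 8

# Each mirrored pair cancels exactly, so the sample mean is exactly zero,
# which removes one source of estimator variance.
print(perturbations.mean(dim=0))
```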
ESTrainer (Base Class)
The base ESTrainer class provides:
- Parameter flattening/unflattening utilities
- Training loop framework
- Fitness evaluation infrastructure
- History tracking
Subclasses implement:
- sample_perturbations(): How to sample perturbations
- compute_update(): How to compute parameter updates from fitnesses
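To make the division of labor concrete, here is a toy, self-contained generation loop in the shape of that contract (hypothetical: the real ESTrainer's loop, attribute names, and signatures may differ), applied to minimizing ||x||²:

```python
import torch

def es_generation(params, sample_perturbations, compute_update,
                  fitness_fn, population_size=32, sigma=0.1, lr=0.01):
    # The framework's job: sample, evaluate, update.
    perturbations = sample_perturbations(population_size)           # (P, D)
    fitnesses = torch.stack([
        fitness_fn(params + sigma * eps) for eps in perturbations  # evaluate candidates
    ])
    update = compute_update(perturbations, fitnesses)               # (D,)
    return params + lr * update

# Toy problem: maximize -||x||^2 (optimum at the origin).
torch.manual_seed(0)
params = torch.ones(5)
for _ in range(200):
    params = es_generation(
        params,
        lambda p: torch.randn(p, params.shape[0]),                       # sample_perturbations
        lambda eps, f: (((f - f.mean()) / (f.std() + 1e-8))[:, None]
                        * eps).mean(dim=0),                              # compute_update
        lambda x: -(x ** 2).sum(),                                       # fitness_fn
    )
print(f"final norm: {params.norm().item():.3f}")  # shrinks from ~2.24 toward 0
```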
Examples
See the examples/ directory:
- basic_example.py ⭐ START HERE - Side-by-side comparison of VanillaESTrainer and EGGROLLTrainer
- mnist_comparison.py - Full EGGROLL vs SGD comparison on MNIST with plots
- run_all_comparisons.py - Multi-architecture comparison (CNN, Transformer, MLP)
- comparison_framework.py - Reusable framework for comparing optimizers
- models.py - Shared model architectures
- utils.py - Shared utility functions
- Test suites: test_comprehensive.py, test_eggroll.py, test_mnist_eggroll.py
3D Reinforcement Learning Examples
Train agents in 3D MuJoCo environments:
See examples/README.md and examples/animals_3d/README.md for detailed documentation.
Key Features
- ✅ EGGROLL algorithm - Low-rank perturbations for massive speedup
- ✅ PyTorch native - Works with any PyTorch model
- ✅ Flexible - Supports custom ES algorithms via subclassing
- ✅ Efficient - Optimized for large population sizes
- ✅ Well-tested - Comprehensive test suite included
References
File details
Details for the file eggroll_trainer-0.1.0.tar.gz.
File metadata
- Download URL: eggroll_trainer-0.1.0.tar.gz
- Upload date:
- Size: 95.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3f2a72a3184c0d1a27c4f8de00a006ef793b158158042d7a0bccd79037f34e19 |
| MD5 | dc15dcd7b22b316e7061be19a74f2de9 |
| BLAKE2b-256 | 76f6e9a43f510c28da29691a5a8c0bfc8af4d539b216bb1af9fa67c12ae71883 |
File details
Details for the file eggroll_trainer-0.1.0-py3-none-any.whl.
File metadata
- Download URL: eggroll_trainer-0.1.0-py3-none-any.whl
- Upload date:
- Size: 13.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4b1db57d82d9f56408bb972b4c2bf68177322e098717f72e2a3d1e8179e04bd8 |
| MD5 | 21244dd951d6c85d8b84bc0adf551e13 |
| BLAKE2b-256 | 625334fc5478dbae924c46ce12eb31594dd27d5713bfec45c0d4eaa0ad688100 |