
Eggroll Trainer

Python 3.12+ | License: MIT | Documentation

A library for Evolution Strategy (ES) trainers in PyTorch, including the EGGROLL algorithm.

Documentation

📚 Full Documentation - Complete guide with examples, API reference, and research background.

Installation

# Using pip
pip install eggroll-trainer

# Or using uv
uv add eggroll-trainer

# For examples with plotting:
pip install "eggroll-trainer[examples]"
# or
uv add eggroll-trainer --extra examples

For development/contributing, see CONTRIBUTING.md.

What is EGGROLL?

EGGROLL (Evolution Guided General Optimization via Low-rank Learning) is a novel ES algorithm that provides a hundredfold increase in training speed over naïve evolution strategies by using low-rank perturbations instead of full-rank ones.

Key innovation: For matrix parameters W ∈ R^(m×n), EGGROLL samples low-rank matrices A ∈ R^(m×r), B ∈ R^(n×r) where r << min(m,n), forming perturbations A @ B.T. This reduces:

  • Memory: O(mn) → O(r(m+n))
  • Computation: O(mn) → O(r(m+n))

Yet still achieves high-rank updates through population averaging!
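To make the savings concrete, here is a small sketch (plain Python, helper names hypothetical) comparing how many noise values must be sampled per matrix for a full-rank versus a rank-r perturbation:

```python
# Sketch: elements sampled per perturbation of an m×n weight matrix.
# full-rank noise: one value per entry          -> m * n
# low-rank noise:  A (m×r) and B (n×r) factors  -> r * (m + n)

def full_rank_elems(m: int, n: int) -> int:
    return m * n

def low_rank_elems(m: int, n: int, r: int) -> int:
    return r * (m + n)

m, n, r = 4096, 4096, 4
print(full_rank_elems(m, n))    # 16777216
print(low_rank_elems(m, n, r))  # 32768 (512x fewer values)
```

For a 4096×4096 layer at rank 4, the low-rank scheme samples roughly 512 times fewer values per population member, which is where the memory and compute savings come from.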

Based on: Evolution Strategies at the Hyperscale

Usage

EGGROLL Trainer (Recommended)

import torch
import torch.nn as nn
from eggroll_trainer import EGGROLLTrainer

# Define a model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)  # Matrix: uses LoRA updates
        self.fc2 = nn.Linear(20, 1)   # Matrix: uses LoRA updates
    
    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Define fitness function (higher is better)
def fitness_fn(model):
    # Your evaluation logic
    return torch.randn(1).item()

# Create EGGROLL trainer
model = SimpleModel()
trainer = EGGROLLTrainer(
    model.parameters(),
    model=model,
    fitness_fn=fitness_fn,
    population_size=256,      # Large populations are efficient!
    learning_rate=0.01,
    sigma=0.1,
    rank=1,                   # Low-rank rank (1 is often sufficient)
    noise_reuse=0,            # 0 = no reuse, 2 = antithetic sampling
    group_size=0,             # 0 = global normalization
    freeze_nonlora=False,     # If True, only update matrix params
    seed=42,
)

# Train
trainer.train(num_generations=100)

Base ES Trainer

For custom ES algorithms, subclass ESTrainer:

from eggroll_trainer import ESTrainer
import torch

class MyESTrainer(ESTrainer):
    def sample_perturbations(self, population_size):
        # One isotropic Gaussian noise vector per population member
        param_dim = self.current_params.shape[0]
        return torch.randn(population_size, param_dim, device=self.device)

    def compute_update(self, perturbations, fitnesses):
        # Standardize fitnesses; the small epsilon guards against a
        # zero standard deviation when all fitnesses are identical
        weights = (fitnesses - fitnesses.mean()) / (fitnesses.std() + 1e-8)
        return (weights[:, None] * perturbations).mean(dim=0)

Architecture

EGGROLLTrainer

The EGGROLLTrainer implements the actual EGGROLL algorithm:

  • Low-rank perturbations for 2D parameters (matrices): Uses A @ B.T where A ∈ R^(m×r), B ∈ R^(n×r)
  • Full-rank perturbations for 1D/3D+ parameters (biases, etc.)
  • Per-layer updates: Handles each parameter tensor independently
  • Fitness normalization: Supports global or group-based normalization
  • Noise reuse: Optional antithetic sampling for efficiency
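The per-parameter dispatch described above can be sketched as follows (stdlib-only, function name hypothetical): 2D tensors get low-rank noise, everything else gets full-rank noise.

```python
import math

# Sketch of per-parameter strategy dispatch: matrices receive low-rank
# A @ B.T perturbations; 1D/3D+ tensors receive full-rank noise.
def perturbation_strategy(shape, rank=1):
    if len(shape) == 2:
        m, n = shape
        return ("low_rank", rank * (m + n))   # elements sampled for A and B
    return ("full_rank", math.prod(shape))    # one element per entry

print(perturbation_strategy((20, 10)))  # ('low_rank', 30)
print(perturbation_strategy((20,)))     # ('full_rank', 20)
```

This mirrors the trainer's behavior for the SimpleModel above: fc1.weight (20×10) is perturbed with rank-1 factors, while fc1.bias (20,) is perturbed directly.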

ESTrainer (Base Class)

The base ESTrainer class provides:

  • Parameter flattening/unflattening utilities
  • Training loop framework
  • Fitness evaluation infrastructure
  • History tracking

Subclasses implement:

  • sample_perturbations(): How to sample perturbations
  • compute_update(): How to compute parameter updates from fitnesses
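As one illustration of what a sample_perturbations() override might do, here is a stdlib-only sketch of antithetic (mirrored) sampling, the idea behind the noise_reuse=2 option: draw half the population, then reuse each draw with its sign flipped.

```python
import random

# Sketch of antithetic sampling: each noise vector in the first half of
# the population is paired with its exact negation in the second half.
def antithetic_noise(population_size, dim, sigma=0.1):
    half = [[random.gauss(0.0, sigma) for _ in range(dim)]
            for _ in range(population_size // 2)]
    mirrored = [[-x for x in row] for row in half]
    return half + mirrored

noise = antithetic_noise(4, 3)
# noise[2] is the mirror of noise[0], so paired perturbations cancel,
# reducing the variance of the fitness-weighted update
```

Mirrored pairs halve the number of random draws per generation and tend to reduce gradient-estimate variance, which is why antithetic sampling is a common ES trick.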

Examples

See the examples/ directory:

  • basic_example.py - START HERE - Side-by-side comparison of VanillaESTrainer and EGGROLLTrainer
  • mnist_comparison.py - Full EGGROLL vs SGD comparison on MNIST with plots
  • run_all_comparisons.py - Multi-architecture comparison (CNN, Transformer, MLP)
  • comparison_framework.py - Reusable framework for comparing optimizers
  • models.py - Shared model architectures
  • utils.py - Shared utility functions
  • Test suites: test_comprehensive.py, test_eggroll.py, test_mnist_eggroll.py

3D Reinforcement Learning Examples

Train agents in 3D MuJoCo environments:

  • Ant Locomotion
  • HalfCheetah Running
  • Humanoid Walking
  • Hopper Locomotion
  • Walker2d
  • Swimmer
  • Reacher

See examples/README.md and examples/animals_3d/README.md for detailed documentation.

Key Features

  • EGGROLL algorithm - Low-rank perturbations for massive speedup
  • PyTorch native - Works with any PyTorch model
  • Flexible - Supports custom ES algorithms via subclassing
  • Efficient - Optimized for large population sizes
  • Well-tested - Comprehensive test suite included

