L0 regularization for sparse neural networks and intelligent sampling

L0 Regularization

A PyTorch implementation of L0 regularization for neural network sparsification and intelligent sampling, based on Louizos, Welling, & Kingma (2017).

Features

  • Hard Concrete Distribution: Differentiable approximation of L0 norm
  • Sparse Neural Network Layers: L0Linear, L0Conv2d with automatic pruning
  • Intelligent Sampling: Sample/feature selection gates for calibration
  • L0L2 Combined Penalty: Recommended approach to prevent overfitting
  • Temperature Scheduling: Annealing for improved convergence
  • Test-Driven Development: Comprehensive test coverage
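The Hard Concrete gate can be sketched in plain Python. This follows the equations in Louizos et al. (2017), not this package's source: a uniform draw is pushed through a temperature-scaled sigmoid, stretched to the interval (γ, ζ) = (-0.1, 1.1), and clipped to [0, 1], so a gate can be exactly zero. The closed-form probability that a gate is non-zero gives the differentiable L0 penalty. The function names below are hypothetical and are not part of the `l0` API; the default temperature of 2/3 is the value used in the paper.

```python
import math
import random

# Stretch limits (gamma, zeta) from Louizos et al.; beta is the temperature.
GAMMA, ZETA = -0.1, 1.1


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gate(log_alpha, beta=2 / 3, rng=None):
    """Draw one hard concrete gate value z in [0, 1] (training-time, noisy)."""
    if rng is None:
        rng = random.Random()
    u = min(max(rng.random(), 1e-6), 1 - 1e-6)  # keep log() finite
    s = sigmoid((math.log(u) - math.log(1 - u) + log_alpha) / beta)
    s_bar = s * (ZETA - GAMMA) + GAMMA  # stretch (0, 1) to (gamma, zeta)
    return min(1.0, max(0.0, s_bar))  # hard-clip: exact 0s and 1s can occur


def prob_nonzero(log_alpha, beta=2 / 3):
    """P(z != 0): a differentiable stand-in for the gate's L0 contribution."""
    return sigmoid(log_alpha - beta * math.log(-GAMMA / ZETA))


def expected_l0(log_alphas, beta=2 / 3):
    """Expected number of open gates (the L0 penalty before scaling by lambda)."""
    return sum(prob_nonzero(a, beta) for a in log_alphas)


def deterministic_gate(log_alpha):
    """Noise-free gate estimate the paper uses at test time."""
    return min(1.0, max(0.0, sigmoid(log_alpha) * (ZETA - GAMMA) + GAMMA))
```

With `log_alpha = 0` and the default temperature, roughly one draw in six lands exactly on 0 and another one in six exactly on 1. That point mass at zero is what lets training switch parameters fully off while the penalty stays differentiable in `log_alpha`.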

Installation

pip install l0-python

For development:

git clone https://github.com/PolicyEngine/L0.git
cd L0
pip install -e ".[dev]"

Quick Start

Neural Network Sparsification

import torch
from l0 import L0Linear, compute_l0l2_penalty, TemperatureScheduler, update_temperatures

# Create a sparse model
class SparseModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = L0Linear(784, 256, init_sparsity=0.5)
        self.fc2 = L0Linear(256, 10, init_sparsity=0.7)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SparseModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()
scheduler = TemperatureScheduler(initial_temp=1.0, final_temp=0.1)

# Placeholder batch (64 flattened 28x28 inputs, 10 classes); replace with your data loader
input_data = torch.randn(64, 784)
target = torch.randint(0, 10, (64,))

# Training loop
for epoch in range(100):
    # Anneal the gate temperature
    temp = scheduler.get_temperature(epoch)
    update_temperatures(model, temp)

    # Forward pass
    output = model(input_data)
    ce_loss = criterion(output, target)

    # Add the combined L0 + L2 penalty
    penalty = compute_l0l2_penalty(model, l0_lambda=1e-3, l2_lambda=1e-4)
    loss = ce_loss + penalty

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
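`TemperatureScheduler` decays the gate temperature from `initial_temp` to `final_temp`, but the shape of that decay is not specified here. As an assumption, the sketch below uses geometric (exponential) interpolation between the two values; `exponential_temperature` and its `n_epochs` argument are hypothetical and not the package's API.

```python
def exponential_temperature(epoch, n_epochs, initial_temp=1.0, final_temp=0.1):
    """Geometric interpolation: initial_temp at epoch 0, final_temp at the last epoch."""
    if n_epochs <= 1:
        return final_temp
    # Fraction of training completed, clamped so late epochs stay at final_temp.
    progress = min(max(epoch / (n_epochs - 1), 0.0), 1.0)
    return initial_temp * (final_temp / initial_temp) ** progress


# Temperatures for a 100-epoch run, matching the quick-start settings.
schedule = [exponential_temperature(e, 100) for e in range(100)]
print(schedule[0], schedule[-1])
```

Compared with a linear ramp, a geometric schedule drops faster early on (its midpoint here is about 0.32 rather than 0.55), so more of training runs near the final, sharper temperature.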

Intelligent Sample Selection

from l0 import SampleGate, HybridGate

# `data` is your dataset, with one row per candidate sample (10,000 rows here)

# Pure L0 selection
gate = SampleGate(n_samples=10000, target_samples=1000)
selected_data, indices = gate.select_samples(data)

# Hybrid selection (25% L0, 75% random)
hybrid = HybridGate(
    n_items=10000,
    l0_fraction=0.25,
    random_fraction=0.75,
    target_items=1000
)
selected, indices, types = hybrid.select(data)
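One plausible reading of the hybrid scheme, sketched with hypothetical names: rank candidates by a learned score (standing in for gate-open probabilities), keep the top `l0_fraction` of the target, then fill the remainder uniformly at random from the items not already chosen. This illustrates that assumption only; it is not `HybridGate`'s actual implementation.

```python
import random


def hybrid_select(scores, target_items, l0_fraction=0.25, seed=None):
    """Top-scoring items fill the L0 share; the rest are drawn uniformly at random.

    `scores` is a hypothetical stand-in for learned gate-open probabilities.
    Returns (indices, types), where each type is "l0" or "random".
    """
    if target_items > len(scores):
        raise ValueError("target_items exceeds the number of candidates")
    n_l0 = round(target_items * l0_fraction)
    # L0 share: highest-scoring candidates first.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    l0_idx = ranked[:n_l0]
    chosen = set(l0_idx)
    # Random share: sample without replacement from everything not yet chosen.
    remaining = [i for i in range(len(scores)) if i not in chosen]
    rng = random.Random(seed)
    rand_idx = rng.sample(remaining, target_items - n_l0)
    types = ["l0"] * len(l0_idx) + ["random"] * len(rand_idx)
    return l0_idx + rand_idx, types
```

The random share guards against the selection depending entirely on the learned scores.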

Feature Selection

from l0 import FeatureGate

# `data` is your dataset, with one column per candidate feature (1,000 columns here)

# Select top features
gate = FeatureGate(n_features=1000, max_features=50)
selected_data, feature_indices = gate.select_features(data)

# Get feature importance
importance = gate.get_feature_importance()
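A minimal sketch of gate-based feature selection, assuming each feature has a learned probability that its gate is open: keep the `max_features` columns with the highest probability, and reuse that ordering as an importance ranking. `select_top_features` and `feature_ranking` are hypothetical helpers on plain lists, not `FeatureGate` methods.

```python
def feature_ranking(gate_probs):
    """Feature indices ordered from most to least likely to be kept."""
    return sorted(range(len(gate_probs)), key=lambda j: gate_probs[j], reverse=True)


def select_top_features(rows, gate_probs, max_features):
    """Keep the max_features columns with the highest gate-open probability."""
    keep = sorted(feature_ranking(gate_probs)[:max_features])  # original column order
    selected = [[row[j] for j in keep] for row in rows]
    return selected, keep


rows = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
probs = [0.1, 0.9, 0.5, 0.8, 0.2]
selected, keep = select_top_features(rows, probs, max_features=2)
print(keep)      # [1, 3]
print(selected)  # [[2, 4], [7, 9]]
```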

Integration with PolicyEngine

This package is designed to work with PolicyEngine's calibration system:

# In policyengine-us-data or similar
from l0 import HardConcrete

# One gate per household for CPS calibration
gates = HardConcrete(
    len(household_weights),
    temperature=0.25,
    init_mean=0.999  # Start with nearly all households active
)

# Apply gates during reweighting
masked_weights = household_weights * gates()
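The household-gating step can be sketched without PyTorch. All of the following are assumptions, not documented behaviour of `HardConcrete`: `init_mean` is mapped to each gate's log-alpha via the logit, the applied gate is the paper's noise-free estimate, and the stretch limits are γ = -0.1, ζ = 1.1. Under those assumptions, `init_mean=0.999` leaves every household at its full weight, and the expected number of active households is just under the total. The helper names and weights are hypothetical.

```python
import math

GAMMA, ZETA = -0.1, 1.1  # hard concrete stretch limits (assumed, as in the paper)


def logit(p):
    return math.log(p) - math.log1p(-p)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def expected_active(log_alphas, temperature):
    """Sum of P(gate != 0): the expected number of households kept."""
    shift = temperature * math.log(-GAMMA / ZETA)
    return sum(sigmoid(a - shift) for a in log_alphas)


def masked_weights(weights, log_alphas):
    """Multiply each weight by its noise-free gate value in [0, 1]."""
    gates = [min(1.0, max(0.0, sigmoid(a) * (ZETA - GAMMA) + GAMMA)) for a in log_alphas]
    return [w * g for w, g in zip(weights, gates)]


# Illustrative household weights, with gates initialised from init_mean=0.999.
household_weights = [1200.0, 950.0, 430.0, 2100.0]
log_alphas = [logit(0.999)] * len(household_weights)
print(masked_weights(household_weights, log_alphas))  # [1200.0, 950.0, 430.0, 2100.0]
print(expected_active(log_alphas, temperature=0.25))  # just under 4
```

As training pushes a household's log-alpha below about -2.4, its noise-free gate reaches exactly 0 and the household drops out of the calibrated sample.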

Documentation

Full documentation is available at: https://policyengine.github.io/L0/

Testing

Run tests with:

pytest tests/ -v --cov=l0

Acknowledgments

This implementation is inspired by and builds upon the original L0 regularization code by AMLab Amsterdam, which accompanied the paper by Louizos et al. (2017; published at ICLR 2018).

Citation

If you use this package, please cite:

@article{louizos2017learning,
  title={Learning Sparse Neural Networks through L0 Regularization},
  author={Louizos, Christos and Welling, Max and Kingma, Diederik P},
  journal={arXiv preprint arXiv:1712.01312},
  year={2017}
}

License

MIT License - see LICENSE file for details.

Download files

Source Distribution

l0_python-0.1.1.tar.gz (26.6 kB)

Uploaded Source

Built Distribution

l0_python-0.1.1-py3-none-any.whl (18.5 kB)

Uploaded Python 3

File details

Details for the file l0_python-0.1.1.tar.gz.

File metadata

  • Download URL: l0_python-0.1.1.tar.gz
  • Upload date:
  • Size: 26.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for l0_python-0.1.1.tar.gz
Algorithm Hash digest
SHA256 e6e81b023f7d7dbb5431cc06853de4e8ca5179a1d0c5b6bf9456b224291cae20
MD5 a9535e02d25c17f5bf45b255262a9a72
BLAKE2b-256 fb7972f0ad14b2dd128dab1476e164bd7496b3e405bf3592aca44f3665ea1a8b

File details

Details for the file l0_python-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: l0_python-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 18.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for l0_python-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 9b54079b3447a55ea79f12e20770dbfd25b882e7066a5c3464329f44d7831bb5
MD5 d3af068ef7bde2fd960fbc9fdcbe33b4
BLAKE2b-256 9365dfad68a9e210072c3617e86b0f99a452a270a68cd9edae4eafbb295c6a8d
