L0 regularization for sparse neural networks and intelligent sampling
L0 Regularization
A PyTorch implementation of L0 regularization based on Louizos, Welling, & Kingma (2017), with a critical bug fix to the original authors' implementation.
Why This Package?
The original L0 implementation from AMLab-Amsterdam omits the temperature parameter when it computes gates at test time. As a result, gates do not close fully, sparsity suffers, and performance degrades. Our implementation corrects this error:
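```python
import torch
log_alpha = torch.randn(5)  # example gate logits
temperature = 0.35  # example temperature (beta)
# Original (buggy): gates never fully close
pi = torch.sigmoid(log_alpha)
# Corrected: temperature division required for proper sparsity
pi = torch.sigmoid(log_alpha / temperature)
```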
This fix enables gates to achieve true 0/1 values, producing exact sparsity as intended by the L0 formulation.
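To see why the temperature matters, the sketch below computes the deterministic test-time gate in plain Python. It assumes the stretch-and-clamp form of the Hard Concrete gate from the paper, with `gamma=-0.1`, `zeta=1.1`, and `beta=0.35` borrowed from the calibration example further down; the `use_temperature` flag is a hypothetical switch for comparing the two formulas, not part of the package API.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def inference_gate(
    log_alpha: float,
    beta: float,
    gamma: float = -0.1,
    zeta: float = 1.1,
    use_temperature: bool = True,  # hypothetical toggle for comparison only
) -> float:
    """Deterministic gate used at inference (sketch, not the package's code)."""
    logit = log_alpha / beta if use_temperature else log_alpha
    pi = sigmoid(logit)
    # Stretch (0, 1) to (gamma, zeta), then clamp so exact 0 and 1 are reachable
    return min(1.0, max(0.0, pi * (zeta - gamma) + gamma))


for log_alpha in (-2.0, 2.0):
    without_t = inference_gate(log_alpha, beta=0.35, use_temperature=False)
    with_t = inference_gate(log_alpha, beta=0.35)
    print(f"log_alpha={log_alpha:+.1f}: without T={without_t:.3f}, with T={with_t:.3f}")
```

With these settings, a logit of -2.0 gives a gate of about 0.04 without the temperature division and exactly 0.0 with it; a logit of +2.0 gives about 0.96 versus exactly 1.0. Dividing by a temperature below 1 sharpens the sigmoid, so the clamp is reached at smaller values of |log_alpha|.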
Installation
```shell
pip install l0-python
```
For development:
```shell
git clone https://github.com/PolicyEngine/L0.git
cd L0
pip install -e ".[dev]"  # quotes keep shells like zsh from expanding the brackets
```
Primary Use Case: Survey Calibration
This package was developed for PolicyEngine's survey calibration system, where we select a sparse subset of survey households while matching population targets.
```python
import numpy as np
import torch
from scipy import sparse as sp
from l0.calibration import SparseCalibrationWeights

# Setup: Q targets, N households
Q, N = 200, 10000
M = sp.random(Q, N, density=0.3, format="csr")  # Household characteristics (Q x N)
y = np.random.uniform(1e6, 1e8, size=Q)  # Population targets

# Initialize model
model = SparseCalibrationWeights(
    n_features=N,
    beta=0.35,  # Hard Concrete temperature
    gamma=-0.1,  # Lower stretch limit
    zeta=1.1,  # Upper stretch limit
    init_keep_prob=0.5,  # Start with 50% active probability
    init_weights=1.0,  # Or pass an array of initial weights
    log_weight_jitter_sd=0.05,
    device="cuda" if torch.cuda.is_available() else "cpu",  # Use the GPU when available
)

# Train with L0+L2 regularization
model.fit(
    M=M,
    y=y,
    lambda_l0=1e-6,  # Controls sparsity level
    lambda_l2=1e-8,  # Prevents weight explosion
    lr=0.15,
    epochs=2000,
    loss_type="relative",  # Scale-invariant loss
    verbose=True,
    verbose_freq=100,
)

# Get results
active = model.get_active_weights()
print(f"Selected {active['count']} of {N} households ({100 * active['count'] / N:.1f}%)")
print(f"Sparsity: {model.get_sparsity():.1%}")

# Predict calibrated totals
y_pred = model.predict(M)
```
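One way to read the `fit` call, assuming it combines the relative loss with the paper's expected-L0 penalty and an L2 term on the weights (the exact form inside the package may differ): write the household weights as $w_j = \exp(\theta_j)$ in log space (the symbol $\theta_j$ is our notation), let $z_j$ be the Hard Concrete gates with logits $\log\alpha_j$, and take $\beta$, $\gamma$, $\zeta$ from the `beta`, `gamma`, `zeta` arguments. Training then minimizes approximately

$$
\mathcal{L}(\theta, \log\alpha) =
\frac{1}{Q}\sum_{q=1}^{Q}\left(\frac{\big(M(w \odot z)\big)_q - y_q}{y_q}\right)^{2}
+ \lambda_{L0}\sum_{j=1}^{N}\sigma\!\left(\log\alpha_j - \beta\log\frac{-\gamma}{\zeta}\right)
+ \lambda_{L2}\sum_{j=1}^{N} w_j^{2}
$$

where $\sigma$ is the logistic sigmoid and the middle sum is the expected number of non-zero gates. Under this reading, `lambda_l0` trades target fit against the number of selected households, and `lambda_l2` keeps the surviving weights from growing without bound.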
Key Features for Calibration
- Non-negative weights: Weights are parameterized in log space, so every active weight is strictly positive; weights whose gates close are exactly zero
- Sparse solutions: L0 penalty directly minimizes the count of active weights
- Relative loss: Scale-invariant loss for targets spanning orders of magnitude
- Group-wise averaging: Balance loss contributions across target groups with different cardinalities
- GPU support: CUDA acceleration for large-scale problems
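The relative loss and group-wise averaging can be illustrated with a minimal pure-Python sketch. It assumes squared relative error and equal weight per group; the package's `loss_type="relative"` and its group handling may be implemented differently, and the group labels below are hypothetical.

```python
from collections import defaultdict
from typing import Hashable, Optional, Sequence


def relative_loss(
    y_pred: Sequence[float],
    y: Sequence[float],
    groups: Optional[Sequence[Hashable]] = None,
) -> float:
    """Mean squared relative error, optionally balanced across target groups (sketch)."""
    # Relative error makes a 10% miss cost the same on a target of 100 or 1e8
    errors = [((p - t) / t) ** 2 for p, t in zip(y_pred, y)]
    if groups is None:
        return sum(errors) / len(errors)
    by_group = defaultdict(list)
    for err, g in zip(errors, groups):
        by_group[g].append(err)
    # Average within each group first, so large groups do not drown out small ones
    group_means = [sum(v) / len(v) for v in by_group.values()]
    return sum(group_means) / len(group_means)


y = [1e8, 100.0, 200.0, 300.0, 400.0]
y_pred = [1.1e8, 100.0, 200.0, 300.0, 400.0]
groups = ["income", "age", "age", "age", "age"]  # hypothetical group labels
print(relative_loss(y_pred, y))  # plain mean over 5 targets
print(relative_loss(y_pred, y, groups))  # mean of the two group means
```

Here the single "income" target misses by 10%, so its squared relative error is 0.01. It contributes about 0.01 / 5 = 0.002 to the plain mean, but about 0.01 / 2 = 0.005 once each group gets equal weight.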
Neural Network Sparsification
The package also supports traditional neural network pruning:
```python
import torch
from l0 import L0Linear, compute_l0l2_penalty, TemperatureScheduler, update_temperatures


class SparseModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = L0Linear(784, 256, init_sparsity=0.5)
        self.fc2 = L0Linear(256, 10, init_sparsity=0.7)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = SparseModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()
scheduler = TemperatureScheduler(initial_temp=1.0, final_temp=0.1)

# Placeholder batch with MNIST-like shapes; replace with a real DataLoader
input_data = torch.randn(64, 784)
target = torch.randint(0, 10, (64,))

for epoch in range(100):
    # Anneal the gate temperature and push it to every L0 layer
    temp = scheduler.get_temperature(epoch)
    update_temperatures(model, temp)

    output = model(input_data)
    ce_loss = criterion(output, target)
    penalty = compute_l0l2_penalty(model, l0_lambda=1e-3, l2_lambda=1e-4)
    loss = ce_loss + penalty

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
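`TemperatureScheduler` is described here only by its start and end values. As an illustration of what annealing from 1.0 to 0.1 can look like, the sketch below assumes a geometric (exponential) decay over a fixed number of epochs; both the decay shape and the `total_epochs` parameter are assumptions, not the package's documented behavior.

```python
def annealed_temperature(
    epoch: int,
    initial_temp: float = 1.0,
    final_temp: float = 0.1,
    total_epochs: int = 100,  # hypothetical: the schedule length is an assumption
) -> float:
    """Geometric decay from initial_temp to final_temp, held at final_temp afterwards (sketch)."""
    progress = min(max(epoch / max(total_epochs - 1, 1), 0.0), 1.0)
    return initial_temp * (final_temp / initial_temp) ** progress


schedule = [annealed_temperature(e) for e in range(100)]
print(schedule[0], schedule[49], schedule[-1])
```

A geometric schedule drops the temperature quickly at first and then flattens out near `final_temp`; a linear schedule is an equally plausible choice, and the package may use either or something else.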
Available Layers
- L0Linear: Fully connected layer with L0 gates on weights
- L0Conv2d: 2D convolution with channel-wise L0 gates
- L0DepthwiseConv2d: Depthwise convolution with L0 gates
- SparseMLP: Multi-layer perceptron with built-in L0 regularization
Intelligent Sampling Gates
Standalone gates for sample/feature selection:
```python
import torch
from l0 import SampleGate, FeatureGate, HybridGate

# Placeholder data: 10,000 samples x 1,000 features (assumes tensor input)
data = torch.randn(10000, 1000)

# Select samples via learned gates
gate = SampleGate(n_samples=10000, target_samples=1000)
selected_data, indices = gate.select_samples(data)

# Select features
gate = FeatureGate(n_features=1000, max_features=50)
selected_data, indices = gate.select_features(data)

# Hybrid: combine L0 selection with random sampling
hybrid = HybridGate(
    n_items=10000,
    l0_fraction=0.25,  # 25% via learned L0 gates
    random_fraction=0.75,  # 75% random for coverage
)
selected, indices, types = hybrid.select(data)
```
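The selection step behind these gates can be sketched without the package: rank items by their learned gate logits (equivalently, their keep probabilities) and, for the hybrid case, fill the rest of the budget uniformly at random from items the gates did not pick. This is a hypothetical illustration under the assumption that the 25%/75% split applies to the selected items; it does not reproduce how `SampleGate`, `FeatureGate`, or `HybridGate` train their gates, and `n_select` is an assumed parameter.

```python
import random
from typing import List, Sequence, Tuple


def top_k_by_gate(log_alpha: Sequence[float], k: int) -> List[int]:
    """Indices of the k items with the highest gate logits, i.e. highest keep probability."""
    order = sorted(range(len(log_alpha)), key=lambda i: log_alpha[i], reverse=True)
    return sorted(order[:k])


def hybrid_select(
    log_alpha: Sequence[float],
    n_select: int,  # hypothetical total selection budget
    l0_fraction: float = 0.25,
    seed: int = 0,
) -> Tuple[List[int], List[str]]:
    """Fill part of the budget by gate score and the rest uniformly at random (sketch)."""
    n_l0 = round(n_select * l0_fraction)
    l0_idx = top_k_by_gate(log_alpha, n_l0)
    chosen = set(l0_idx)
    remaining = [i for i in range(len(log_alpha)) if i not in chosen]
    # Random draws come only from items the gates did not already pick
    random_idx = random.Random(seed).sample(remaining, n_select - n_l0)
    types = ["l0"] * len(l0_idx) + ["random"] * len(random_idx)
    return l0_idx + random_idx, types


log_alpha = [0.1, 3.0, -1.0, 2.0, 0.5, -2.0, 1.0, 0.0]
indices, types = hybrid_select(log_alpha, n_select=4)
print(indices, types)
```

The random share gives coverage of items the gates currently score low, which guards against the learned selection collapsing onto a narrow subset.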
How L0 Regularization Works
Unlike post-hoc pruning (setting small weights to zero), L0 regularization integrates sparsity into the training objective:
- Stochastic gates: Each weight has a learned gate parameter controlling activation probability
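- Hard Concrete distribution: A stretched and clamped continuous relaxation of binary gates that puts probability mass on exactly 0 and 1 while staying differentiable through reparameterization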
- Expected L0 penalty: Minimizes the expected number of active gates
- Temperature annealing: Gradually sharpens gates from soft to hard decisions
The result: the network learns which weights should be zero as part of optimization, not as a post-processing step.
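The bullets above can be made concrete with a pure-Python sketch of the Hard Concrete gate and its expected-L0 penalty, following the formulas in Louizos et al. (2017). The defaults `beta=2/3`, `gamma=-0.1`, `zeta=1.1` are the values used in the paper; the package's layers do this with PyTorch tensors and autograd, and their internals may differ.

```python
import math
import random
from typing import Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def sample_hard_concrete(
    log_alpha: float,
    rng: random.Random,
    beta: float = 2 / 3,
    gamma: float = -0.1,
    zeta: float = 1.1,
) -> float:
    """Draw one training-time gate value from the Hard Concrete distribution."""
    u = rng.uniform(1e-6, 1 - 1e-6)  # keep u away from 0 and 1 so the logs are finite
    s = sigmoid((math.log(u) - math.log(1 - u) + log_alpha) / beta)
    s_bar = s * (zeta - gamma) + gamma  # stretch (0, 1) to (gamma, zeta)
    return min(1.0, max(0.0, s_bar))  # clamp: puts real probability mass on 0 and 1


def expected_l0(
    log_alphas: Sequence[float],
    beta: float = 2 / 3,
    gamma: float = -0.1,
    zeta: float = 1.1,
) -> float:
    """Expected number of non-zero gates, sum_j P(z_j != 0), a smooth function of log_alpha."""
    shift = beta * math.log(-gamma / zeta)
    return sum(sigmoid(la - shift) for la in log_alphas)


rng = random.Random(0)
samples = [sample_hard_concrete(0.0, rng) for _ in range(20_000)]
frac_nonzero = sum(z > 0 for z in samples) / len(samples)
print(f"empirical P(z != 0) = {frac_nonzero:.3f}, analytic = {expected_l0([0.0]):.3f}")
```

Because the clamp maps whole ranges of stretched values to exactly 0 or 1, sampled gates land on those values with nonzero probability, while `expected_l0` remains smooth in `log_alpha`, so gradient descent can lower it directly.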
Testing
```shell
pytest tests/ -v --cov=l0
```
Citation
If you use this package, please cite the original paper:
```bibtex
@article{louizos2017learning,
  title={Learning Sparse Neural Networks through L0 Regularization},
  author={Louizos, Christos and Welling, Max and Kingma, Diederik P},
  journal={arXiv preprint arXiv:1712.01312},
  year={2017}
}
```
License
MIT License - see LICENSE for details.
Download files
File details
Details for the file l0_python-0.4.2.tar.gz.
File metadata
- Download URL: l0_python-0.4.2.tar.gz
- Upload date:
- Size: 37.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b92364edf1c2aa41433413df4ffec1de5a7d34e27f9271d4320382a01443322c |
| MD5 | 1ea9720760b4c0988513f99af0b5cd3e |
| BLAKE2b-256 | b1b243381b09d35da0be297dca081c9121102794c768225c97542d922c7ca08a |
File details
Details for the file l0_python-0.4.2-py3-none-any.whl.
File metadata
- Download URL: l0_python-0.4.2-py3-none-any.whl
- Upload date:
- Size: 24.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 89b0cbd3381259668f1a3e8b02f5f2affee682742a66d739e491b2df409fa13b |
| MD5 | 1438f10238b015802c7576f9b2c95f77 |
| BLAKE2b-256 | 2c4d15e064365d5b0953ce80243756c5a0e0d80a4bd3f231075d6f0d26f1e386 |