PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
PILOT is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.
Overview
Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.
PILOT introduces a learnable policy that continuously modulates three core update primitives:
- Momentum reliance — how much to rely on accumulated gradient history vs. the current gradient
- Variance-normalization strength — how aggressively to apply adaptive scaling
- Sign-based behavior — how much to compress gradient magnitudes toward ±1
The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. It is updated online during training using a one-step meta-gradient estimate — no offline search, no meta-training phase, no second-order estimation.
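The meta-gradient step is not spelled out here, so the following is only a minimal sketch of the general one-step (hypergradient-style) idea. It rests on simplifying assumptions: a scalar toy loss $f(\theta) = \tfrac{1}{2}\theta^2$, a single learnable logit `a` controlling the momentum blend, and plain momentum without variance normalization or sign compression. The names `toy_meta_gradient_run`, `a`, and `lr_meta` are hypothetical and are not part of the `pilot` package.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def toy_meta_gradient_run(steps: int = 200, lr: float = 0.1,
                          lr_meta: float = 0.05, beta: float = 0.9):
    """Minimise f(theta) = 0.5 * theta**2 while learning a momentum-blend logit.

    Hypothetical illustration of a one-step meta-gradient, not PILOT's code.
    """
    theta, m, a = 5.0, 0.0, 0.0   # parameter, momentum buffer, blend logit
    prev_du_da = 0.0              # d(update_{t-1}) / d(a), kept from last step
    for _ in range(steps):
        g = theta                 # gradient of 0.5 * theta**2 at theta_t
        # theta_t = theta_{t-1} - lr * u_{t-1}(a), so to first order
        # d f(theta_t) / d a = g_t * d theta_t / d a = -lr * g_t * d u_{t-1} / d a
        meta_grad = -lr * g * prev_du_da
        a -= lr_meta * meta_grad  # online gradient step on the policy logit
        m = beta * m + (1 - beta) * g
        p = sigmoid(a)            # blend weight in (0, 1)
        u = p * m + (1 - p) * g   # blend of momentum and current gradient
        prev_du_da = p * (1 - p) * (m - g)  # derivative of u with respect to a
        theta -= lr * u
    return theta, sigmoid(a)


theta, p_m = toy_meta_gradient_run()
print(f"theta={theta:.2e}, learned blend p_m={p_m:.3f}")
```

The one-step approximation shows up in `meta_grad`: the policy's effect on the next loss is taken through the most recent update only, ignoring how it shaped earlier iterates, which is what keeps the cost at a single extra dot product per step.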
PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region than Adam, AdamW, Lion, and Sophia.
Key Results
CNN Architecture
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|---|---|---|---|---|
| FashionMNIST | Adam | 93.28 | 0.1957 | 0.0033 |
| FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
| FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
| FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
| FashionMNIST | PILOT (Ours) | 94.13 | 0.1719 | 0.0045 |
| CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
| CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
| CIFAR-10 | PILOT (Ours) | 81.94 | 0.5302 | 0.0073 |
ResNet-18 Architecture
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|---|---|---|---|---|
| FashionMNIST | AdaBelief | 95.33 | 0.1711 | 0.0056 |
| FashionMNIST | PILOT (Ours) | 95.71 | 0.2690 | 0.0030 |
| CIFAR-10 | Adam | 93.18 | 0.2140 | 0.0073 |
| CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
| CIFAR-10 | PILOT (Ours) | 93.42 | 0.2496 | 0.0001 |
Method
Gradient-Direction Agreement
At each step, PILOT computes the cosine similarity between successive gradients:
$$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$
This is smoothed via an exponential moving average:
$$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$
Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
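A minimal sketch of the agreement signal in plain Python follows. It assumes the gradient is treated as one flattened vector and that $\rho_0 = 0$; neither choice is stated above, and the helper names `cosine_agreement` and `smoothed_agreement` are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_agreement(g_t: Sequence[float], g_prev: Sequence[float],
                     eps: float = 1e-8) -> float:
    """r_t: cosine similarity between the current and previous gradient."""
    dot = sum(a * b for a, b in zip(g_t, g_prev))
    norm_t = math.sqrt(sum(a * a for a in g_t))
    norm_prev = math.sqrt(sum(b * b for b in g_prev))
    return dot / (norm_t * norm_prev + eps)


def smoothed_agreement(grads: List[List[float]], gamma: float = 0.95,
                       rho0: float = 0.0) -> List[float]:
    """rho_t = gamma * rho_{t-1} + (1 - gamma) * r_t over a gradient sequence."""
    rho, history = rho0, []
    for g_prev, g_t in zip(grads, grads[1:]):
        r_t = cosine_agreement(g_t, g_prev)
        rho = gamma * rho + (1 - gamma) * r_t
        history.append(rho)
    return history


# Consistent direction drives rho toward +1; flip-flopping drives it toward -1
steady = [[1.0, 0.5]] * 20
oscillating = [[(-1.0) ** k, 0.0] for k in range(20)]
print(smoothed_agreement(steady)[-1], smoothed_agreement(oscillating)[-1])
```

With $\gamma = 0.95$ the EMA needs roughly $1/(1-\gamma) = 20$ steps to respond, which is why a higher `gamma` yields a steadier but slower-reacting signal.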
Learnable Policy
The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:
$$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$
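Here $f(\rho; \phi) = \sum_{k=0}^{d} \phi_k \rho^k$ is a polynomial of degree $d$, so each control has $d+1$ coefficients and the policy has $3(d+1)$ learnable parameters in total. The factor $\tfrac{1}{2}$ keeps $p_{v,t} \in (0, \tfrac{1}{2})$, so variance normalization never exceeds Adam's square-root scaling.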
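The mapping from $\rho_t$ to the three controls can be sketched as below. The coefficient values are hypothetical, chosen only to show that the same $\rho$ can push two controls in opposite directions; how PILOT initializes or stores $\phi$ is not specified here.

```python
import math
from typing import Sequence, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def poly(rho: float, coeffs: Sequence[float]) -> float:
    """f(rho; phi) = sum_k phi_k * rho**k, evaluated with Horner's rule."""
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * rho + c
    return acc


def policy(rho: float, phi_m: Sequence[float], phi_v: Sequence[float],
           phi_s: Sequence[float]) -> Tuple[float, float, float]:
    p_m = sigmoid(poly(rho, phi_m))        # momentum reliance in (0, 1)
    p_v = 0.5 * sigmoid(poly(rho, phi_v))  # normalization strength in (0, 0.5)
    p_s = sigmoid(poly(rho, phi_s))        # sign compression in (0, 1)
    return p_m, p_v, p_s


degree = 2
# Hypothetical coefficients: index k multiplies rho**k
phi_m = [0.0, 2.0, 0.0]
phi_v = [0.0] * (degree + 1)
phi_s = [0.0, -2.0, 0.0]
n_params = len(phi_m) + len(phi_v) + len(phi_s)  # 3 * (degree + 1)
print(n_params, policy(0.0, phi_m, phi_v, phi_s), policy(1.0, phi_m, phi_v, phi_s))
```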
Update Rule
$$\theta_{t+1} = \theta_t - \eta \, \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \operatorname{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$
where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.
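Because each control is a sigmoid output, these endpoints are approached only in the limit: the formulation recovers Adam ($p_m \to 1$, $p_v \to 0.5$, $p_s \to 0$, up to the small offset $\epsilon_n$) and sign-based updates ($p_s \to 1$, $p_v \to 0$) as limiting cases.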
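The update rule can be checked elementwise with a small sketch. It assumes $\hat{m}_t$ and $\hat{v}_t$ are already bias-corrected, omits weight decay, and uses scalars instead of tensors; `pilot_style_update` is a hypothetical name, not the package's API.

```python
import math
from typing import List, Sequence


def pilot_style_update(theta: Sequence[float], g: Sequence[float],
                       m_hat: Sequence[float], v_hat: Sequence[float],
                       p_m: float, p_v: float, p_s: float, lr: float = 1e-3,
                       eps: float = 1e-8, eps_n: float = 1e-12) -> List[float]:
    """Apply the policy-controlled update rule elementwise (illustrative sketch)."""
    out = []
    for th, gi, mi, vi in zip(theta, g, m_hat, v_hat):
        n = p_m * mi + (1.0 - p_m) * gi                 # blend momentum and gradient
        sign_n = math.copysign(1.0, n) if n != 0.0 else 0.0
        num = (abs(n) + eps_n) ** (1.0 - p_s) * sign_n  # magnitude compression
        den = vi ** p_v + eps                           # partial variance normalization
        out.append(th - lr * num / den)
    return out


theta, g, m_hat, v_hat = [0.0, 0.0], [0.3, -0.1], [0.2, -0.05], [0.04, 0.01]
adam_like = pilot_style_update(theta, g, m_hat, v_hat, p_m=1.0, p_v=0.5, p_s=0.0)
sign_like = pilot_style_update(theta, g, m_hat, v_hat, p_m=0.5, p_v=0.0, p_s=1.0)
print(adam_like, sign_like)
```

With the Adam-like settings the step matches $\eta\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$ up to $\epsilon_n$; with the sign-like settings every coordinate moves by $\eta/(1+\epsilon)$ in the direction of $-\operatorname{sign}(n_t)$, regardless of gradient magnitude.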
Installation
```bash
pip install pilot-optimizer
```
Or install from source:
```bash
git clone https://github.com/SattamAltwaim/PILOT.git
cd PILOT
pip install -e .
```
Usage
```python
from pilot import PILOT

# `model`, `criterion`, and `dataloader` are your own PyTorch objects
optimizer = PILOT(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-4,
    gamma=0.95,   # smoothing coefficient for agreement signal
    lr_phi=0.01,  # policy learning rate
    degree=2,     # polynomial degree
)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```
Hyperparameters
| Parameter | Description | Typical Range |
|---|---|---|
| `lr` | Model learning rate | 1e-4 – 1e-3 |
| `betas` | Moment coefficients | (0.9, 0.999) |
| `gamma` | Agreement signal smoothing | 0.85 – 0.99 |
| `lr_phi` | Policy learning rate | 5e-4 – 5e-2 |
| `degree` | Polynomial degree | 1 – 4 |
Configuration-Specific Selections
| Dataset | Architecture | γ (`gamma`) | η_φ (`lr_phi`) | Degree (`degree`) |
|---|---|---|---|---|
| CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
| CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
| FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
| FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
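To reuse these settings programmatically, a lookup table like the one below could be kept alongside a training script. The values are copied from the table above; the dictionary keys reuse the `--dataset`/`--arch` spellings from the experiment commands, and `PILOT_CONFIGS` and `pilot_kwargs` are hypothetical helpers, not part of the package.

```python
from typing import Dict, Tuple, Union

Number = Union[int, float]

# Per-configuration policy hyperparameters from the table above
PILOT_CONFIGS: Dict[Tuple[str, str], Dict[str, Number]] = {
    ("cifar10", "cnn"): {"gamma": 0.882, "lr_phi": 0.00312, "degree": 1},
    ("cifar10", "resnet18"): {"gamma": 0.950, "lr_phi": 0.00500, "degree": 2},
    ("fashionmnist", "cnn"): {"gamma": 0.950, "lr_phi": 0.01000, "degree": 2},
    ("fashionmnist", "resnet18"): {"gamma": 0.957, "lr_phi": 0.00273, "degree": 3},
}


def pilot_kwargs(dataset: str, arch: str) -> Dict[str, Number]:
    """Return a copy of the policy hyperparameters for a dataset/architecture pair."""
    return dict(PILOT_CONFIGS[(dataset.lower(), arch.lower())])


# e.g. optimizer = PILOT(model.parameters(), lr=1e-3, **pilot_kwargs("cifar10", "cnn"))
print(pilot_kwargs("CIFAR10", "CNN"))
```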
Experiments
All experiments train for 30 epochs with cross-entropy loss, a cosine-annealing learning-rate schedule, batch size 128, and automatic mixed precision (AMP). ResNet-18 configurations additionally use a 3-epoch linear warmup.
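The schedule described above can be sketched as a per-epoch function. The granularity (per epoch rather than per batch), the warmup starting at `base_lr / warmup_epochs`, and a floor of `min_lr = 0` are all assumptions; `lr_at_epoch` is a hypothetical helper, not part of `train.py`.

```python
import math


def lr_at_epoch(epoch: int, base_lr: float = 1e-3, total_epochs: int = 30,
                warmup_epochs: int = 3, min_lr: float = 0.0) -> float:
    """Linear warmup followed by cosine annealing (assumed per-epoch granularity)."""
    if epoch < warmup_epochs:
        # Ramp linearly from base_lr / warmup_epochs up to base_lr
        return base_lr * (epoch + 1) / warmup_epochs
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = (epoch - warmup_epochs) / decay_epochs  # 0 right after warmup
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


resnet_schedule = [lr_at_epoch(e) for e in range(30)]                # 3-epoch warmup
cnn_schedule = [lr_at_epoch(e, warmup_epochs=0) for e in range(30)]  # no warmup
print(resnet_schedule[:5], cnn_schedule[-1])
```

In PyTorch, a comparable schedule can be assembled from `torch.optim.lr_scheduler.LinearLR` and `CosineAnnealingLR` chained with `SequentialLR`; whether the repository's `train.py` does exactly this is not stated here.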
```bash
# CNN on CIFAR-10
python train.py --dataset cifar10 --arch cnn --optimizer pilot

# ResNet-18 on FashionMNIST
python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
```
File details
Details for the file pilot_optimizer-0.1.0.tar.gz.
File metadata
- Download URL: pilot_optimizer-0.1.0.tar.gz
- Upload date:
- Size: 15.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c1e682513d2dc795495ec172d03d133a39aeb6738f7a3f3e59f9bc71dfe92fa6 |
| MD5 | 60052032561d9552bfe12e85d5be0a50 |
| BLAKE2b-256 | 6c3414dde7e056df72af9cef012d8ab69a2dcc6d4be63c37091bde54bd14ce25 |
File details
Details for the file pilot_optimizer-0.1.0-py3-none-any.whl.
File metadata
- Download URL: pilot_optimizer-0.1.0-py3-none-any.whl
- Upload date:
- Size: 10.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 49ef76b437397165b5b18a9105b7b2dc3d30adbdfa96dfa89ba358692e13d70b |
| MD5 | 8a2887a18f2630c058674c27dfa889ea |
| BLAKE2b-256 | 364e0ac5d5cb510e984e398115be5e70704528009a16668ac802cdddc4d9108b |