
PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training


PILOT is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.


Overview

Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.

PILOT introduces a learnable policy that continuously modulates three core update primitives:

  • Momentum reliance — how much to rely on accumulated gradient history vs. the current gradient
  • Variance-normalization strength — how aggressively to apply adaptive scaling
  • Sign-based behavior — how much to compress gradient magnitudes toward ±1

The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. The policy parameters are themselves updated online during training using a one-step meta-gradient estimate: no offline search, no meta-training phase, and no second-order estimation.

[Figure: Loss landscape, CIFAR-10 / SmallCNN] PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region than Adam, AdamW, Lion, and Sophia.


Key Results

CNN Architecture

Dataset        Optimizer      Accuracy (%) ↑   Val Loss ↓   Loss Var. ↓
FashionMNIST   Adam           93.28            0.1957       0.0033
FashionMNIST   AdamW          93.22            0.1944       0.0034
FashionMNIST   Lion           92.91            0.2091       0.0041
FashionMNIST   AdaBelief      93.66            0.1822       0.0046
FashionMNIST   PILOT (Ours)   94.13            0.1719       0.0045
CIFAR-10       Adam           79.91            0.5794       0.0103
CIFAR-10       Lion           80.87            0.5487       0.0105
CIFAR-10       PILOT (Ours)   81.94            0.5302       0.0073

ResNet-18 Architecture

Dataset        Optimizer      Accuracy (%) ↑   Val Loss ↓   Loss Var. ↓
FashionMNIST   AdaBelief      95.33            0.1711       0.0056
FashionMNIST   PILOT (Ours)   95.71            0.2690       0.0030
CIFAR-10       Adam           93.18            0.2140       0.0073
CIFAR-10       AdamW          92.90            0.2514       0.0066
CIFAR-10       PILOT (Ours)   93.42            0.2496       0.0001

Method

Gradient-Direction Agreement

At each step, PILOT computes the cosine similarity between successive gradients:

$$r_t = \frac{g_t^\top g_{t-1}}{\lVert g_t \rVert_2 \, \lVert g_{t-1} \rVert_2 + \epsilon}$$

This is smoothed via an exponential moving average:

$$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$

Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
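The agreement signal reduces to a few lines of code. This is a minimal sketch that assumes each $g_t$ is the full gradient flattened into a single vector and that $\rho_0 = 0$; whether PILOT computes $r_t$ globally or per parameter tensor, and how it initializes $\rho$, is not stated here. The function names `cosine_agreement` and `smooth_agreement` are hypothetical.

```python
import math


def cosine_agreement(g, g_prev, eps=1e-8):
    """r_t = <g_t, g_{t-1}> / (||g_t||_2 * ||g_{t-1}||_2 + eps)."""
    dot = sum(a * b for a, b in zip(g, g_prev))
    norms = math.sqrt(sum(a * a for a in g)) * math.sqrt(sum(b * b for b in g_prev))
    return dot / (norms + eps)


def smooth_agreement(rho_prev, r, gamma=0.95):
    """rho_t = gamma * rho_{t-1} + (1 - gamma) * r_t."""
    return gamma * rho_prev + (1.0 - gamma) * r


# Toy sequence of flattened gradients: aligned, nearly aligned, then reversed.
grads = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]]
rho = 0.0  # assumed initial value of the smoothed signal
for g_prev, g in zip(grads, grads[1:]):
    r = cosine_agreement(g, g_prev)
    rho = smooth_agreement(rho, r, gamma=0.9)
    print(f"r_t = {r:+.3f}, rho_t = {rho:+.3f}")
```

The $\epsilon$ in the denominator also keeps $r_t$ finite (zero) when either gradient is exactly zero.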

Learnable Policy

The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:

$$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$

The total number of learnable policy parameters is $3(d+1)$, where $d$ is the polynomial degree.
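A minimal sketch of the policy, assuming $f(\rho; \phi)$ is an ordinary power-basis polynomial $\sum_{k=0}^{d} \phi_k \rho^k$. That choice gives $d+1$ coefficients per control and $3(d+1)$ in total, matching the count above. The basis, the zero initialization, and the helper names (`poly`, `policy`) are assumptions for illustration, and the meta-gradient update of the coefficients is not shown.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def poly(rho, phi):
    """f(rho; phi) = phi[0] + phi[1]*rho + ... + phi[d]*rho**d (assumed power basis)."""
    return sum(c * rho ** k for k, c in enumerate(phi))


def policy(rho, phi_m, phi_v, phi_s):
    """Map the smoothed agreement rho_t to the three control variables."""
    p_m = sigmoid(poly(rho, phi_m))        # momentum reliance, in (0, 1)
    p_v = 0.5 * sigmoid(poly(rho, phi_v))  # normalization exponent, in (0, 0.5)
    p_s = sigmoid(poly(rho, phi_s))        # sign compression, in (0, 1)
    return p_m, p_v, p_s


degree = 2
# Hypothetical initialization: all-zero coefficients give p_m = 0.5, p_v = 0.25, p_s = 0.5.
phi_m = [0.0] * (degree + 1)
phi_v = [0.0] * (degree + 1)
phi_s = [0.0] * (degree + 1)
n_params = len(phi_m) + len(phi_v) + len(phi_s)  # 3 * (degree + 1)
print(n_params, policy(0.3, phi_m, phi_v, phi_s))
```

The factor of ½ on $p_v$ caps the normalization exponent below Adam's square root, so the policy can interpolate between no normalization and Adam-style normalization but not beyond it.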

Update Rule

$$\theta_{t+1} = \theta_t - \eta \, \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \operatorname{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$

where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.

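This formulation recovers Adam ($p_m = 1$, $p_v = 0.5$, $p_s = 0$, up to the $\epsilon_n$ offset) and sign-based updates ($p_s = 1$, $p_v = 0$) as special cases. Because the controls are sigmoid outputs, with $p_m, p_s \in (0, 1)$ and $p_v \in (0, 0.5)$, these cases are approached as limits rather than reached exactly.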
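The update rule can be sketched coordinate-wise in plain Python. This is a minimal sketch, not the package's implementation: it assumes $\hat{m}_t$ and $\hat{v}_t$ are already-computed moment estimates (how PILOT computes and bias-corrects them is not specified here), takes the three controls as given, treats $\operatorname{sign}(0)$ as 0, and omits weight decay. `pilot_direction` and `pilot_step` are hypothetical names. The last lines check the two special cases listed above.

```python
import math


def pilot_direction(m_hat, g, v_hat, p_m, p_v, p_s, eps=1e-8, eps_n=1e-8):
    """Per-coordinate update direction, before scaling by the learning rate eta."""
    direction = []
    for m_i, g_i, v_i in zip(m_hat, g, v_hat):
        n_i = p_m * m_i + (1.0 - p_m) * g_i             # n_t: blend of momentum and gradient
        magnitude = (abs(n_i) + eps_n) ** (1.0 - p_s)  # exponent 1 - p_s compresses |n_t| toward 1
        sign = math.copysign(1.0, n_i) if n_i != 0.0 else 0.0
        denom = v_i ** p_v + eps                        # variance normalization with exponent p_v
        direction.append(magnitude * sign / denom)
    return direction


def pilot_step(theta, direction, lr):
    """theta_{t+1} = theta_t - eta * direction."""
    return [t - lr * d for t, d in zip(theta, direction)]


m_hat, g, v_hat = [0.2, -0.4], [1.0, 1.0], [0.04, 0.16]

# Adam limit (p_m = 1, p_v = 0.5, p_s = 0, eps_n = 0): m_hat / (sqrt(v_hat) + eps).
adam_like = pilot_direction(m_hat, g, v_hat, p_m=1.0, p_v=0.5, p_s=0.0, eps_n=0.0)
adam = [m / (math.sqrt(v) + 1e-8) for m, v in zip(m_hat, v_hat)]

# Sign limit (p_s = 1, p_v = 0): every coordinate moves by sign(n_t) / (1 + eps).
sign_like = pilot_direction(m_hat, g, v_hat, p_m=1.0, p_v=0.0, p_s=1.0)

print(adam_like, adam, sign_like)
print(pilot_step([1.0, 1.0], adam_like, lr=1e-3))
```

With intermediate control values the same function moves smoothly between these regimes, which is what lets a single scalar signal $\rho_t$ steer the optimizer's character.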


Installation

pip install pilot-optimizer

Or install from source:

git clone https://github.com/SattamAltwaim/PILOT.git
cd PILOT
pip install -e .

Usage

from pilot import PILOT

optimizer = PILOT(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-4,
    gamma=0.95,        # smoothing coefficient for agreement signal
    lr_phi=0.01,       # policy learning rate
    degree=2           # polynomial degree
)

# model, criterion, and dataloader are assumed to be defined elsewhere
for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

Hyperparameters

Parameter   Description                  Typical Range
lr          Model learning rate          1e-4 to 1e-3
betas       Moment coefficients          (0.9, 0.999)
gamma       Agreement signal smoothing   0.85 to 0.99
lr_phi      Policy learning rate         5e-4 to 5e-2
degree      Polynomial degree            1 to 4

Configuration-Specific Selections

Dataset        Architecture   γ (gamma)   η_φ (lr_phi)   Degree (degree)
CIFAR-10       CNN            0.882       0.00312        1
CIFAR-10       ResNet-18      0.950       0.00500        2
FashionMNIST   CNN            0.950       0.01000        2
FashionMNIST   ResNet-18      0.957       0.00273        3
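One way to keep these selections next to the training code is a small lookup that returns constructor keyword arguments. This is a hypothetical helper (`PILOT_CONFIGS` and `pilot_kwargs` are not part of the package); it assumes γ maps to `gamma` and η_φ to `lr_phi`, as the Usage comments suggest, and reuses the `--dataset` / `--arch` spellings from the `train.py` commands in this README.

```python
# Hypothetical mapping from the per-configuration table to PILOT keyword arguments.
# Keys reuse the --dataset / --arch spellings of the train.py commands.
PILOT_CONFIGS = {
    ("cifar10", "cnn"): {"gamma": 0.882, "lr_phi": 0.00312, "degree": 1},
    ("cifar10", "resnet18"): {"gamma": 0.950, "lr_phi": 0.00500, "degree": 2},
    ("fashionmnist", "cnn"): {"gamma": 0.950, "lr_phi": 0.01000, "degree": 2},
    ("fashionmnist", "resnet18"): {"gamma": 0.957, "lr_phi": 0.00273, "degree": 3},
}


def pilot_kwargs(dataset, arch):
    """Return a copy of the listed gamma / lr_phi / degree for one configuration."""
    key = (dataset.lower(), arch.lower())
    if key not in PILOT_CONFIGS:
        raise KeyError(f"no listed selection for {dataset}/{arch}")
    return dict(PILOT_CONFIGS[key])


print(pilot_kwargs("cifar10", "cnn"))
# e.g. PILOT(model.parameters(), lr=1e-3, **pilot_kwargs("cifar10", "cnn"))
```

Returning a copy means a caller can tweak one run's settings without mutating the shared table.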

Experiments

All experiments train for 30 epochs with cross-entropy loss, a cosine-annealing learning-rate schedule, batch size 128, and automatic mixed precision (AMP). ResNet-18 configurations additionally use a 3-epoch linear warmup.

# CNN on CIFAR-10
python train.py --dataset cifar10 --arch cnn --optimizer pilot

# ResNet-18 on FashionMNIST
python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
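The schedule described above (optional linear warmup, then cosine annealing over 30 epochs) can be sketched per epoch as follows. This is a minimal sketch assuming the warmup ramps linearly up to the base rate and the cosine then decays toward `min_lr = 0` over the remaining epochs; how `train.py` actually parameterizes warmup and annealing (per epoch or per step, final rate) is not stated, and `lr_at_epoch` is a hypothetical helper.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-3, total_epochs=30, warmup_epochs=0, min_lr=0.0):
    """Learning rate for a 0-indexed epoch: optional linear warmup, then cosine annealing."""
    if epoch < warmup_epochs:
        # Linear ramp: base_lr / warmup_epochs, ..., base_lr at the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay from base_lr toward min_lr over the remaining epochs.
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


cnn_schedule = [lr_at_epoch(e) for e in range(30)]                      # CNN: no warmup
resnet_schedule = [lr_at_epoch(e, warmup_epochs=3) for e in range(30)]  # ResNet-18: 3-epoch warmup
print(resnet_schedule[:5])
```

In PyTorch, a similar shape can be assembled from `torch.optim.lr_scheduler.LinearLR` and `CosineAnnealingLR` chained with `SequentialLR`; whether the training script does so is not stated here.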



Download files


Source Distribution

pilot_optimizer-0.1.0.tar.gz (15.8 kB)

Built Distribution


pilot_optimizer-0.1.0-py3-none-any.whl (10.7 kB)

File details

Details for the file pilot_optimizer-0.1.0.tar.gz.

File metadata

  • Download URL: pilot_optimizer-0.1.0.tar.gz
  • Size: 15.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for pilot_optimizer-0.1.0.tar.gz
Algorithm Hash digest
SHA256 c1e682513d2dc795495ec172d03d133a39aeb6738f7a3f3e59f9bc71dfe92fa6
MD5 60052032561d9552bfe12e85d5be0a50
BLAKE2b-256 6c3414dde7e056df72af9cef012d8ab69a2dcc6d4be63c37091bde54bd14ce25


File details

Details for the file pilot_optimizer-0.1.0-py3-none-any.whl.

File hashes

Hashes for pilot_optimizer-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 49ef76b437397165b5b18a9105b7b2dc3d30adbdfa96dfa89ba358692e13d70b
MD5 8a2887a18f2630c058674c27dfa889ea
BLAKE2b-256 364e0ac5d5cb510e984e398115be5e70704528009a16668ac802cdddc4d9108b

