Forward Pass Weight Amortization Protocol — invert the inference loop for large transformer models.
fpwap — Forward Pass Weight Amortization Protocol
A single-purpose library for running activation extraction over large transformer models whose weights don't fit in your GPU, across datasets of thousands of prompts, on consumer hardware, at full precision.
The regime
You're a mech-interp researcher. Your model is bigger than your VRAM. Your dataset is thousands of prompts. Adjacent tools each fail in a way that changes what you're studying:
- Quantization (bitsandbytes, GPTQ) changes the activations you're trying to read.
- Inference servers (vLLM, TGI) optimize next-token throughput, not residual-stream extraction.
- accelerate.cpu_offload streams weights once per batch: 10k prompts × 80 layers on a 70B model is hundreds of TB of weight I/O and hours of wall-clock per dataset pass.
- Cloud GPUs break your interactive iteration loop and cost hundreds per experiment.
fpwap inverts the inference loop: load each layer once, stream the whole dataset through it, spill intermediates to disk, move on. Total weight I/O drops from O(N_batches × N_layers) to O(N_layers). A 10k-sample Llama-3.1-70B extraction on a 32 GB consumer GPU therefore moves roughly as many weight bytes as a single batch does under the naive approach, using the original bf16 weights: no quantization, no cloud.
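The inversion can be illustrated with a toy model. This is a sketch, not fpwap's engine: "layers" are scalar functions, a "weight load" is just a counter, and the disk spill is an in-memory list. It shows that the two schedules compute identical outputs while the number of weight loads drops from N_batches × N_layers to N_layers.

```python
from typing import Callable, List, Tuple

Layer = Callable[[float], float]


def naive_schedule(layers: List[Layer], batches: List[List[float]]) -> Tuple[List[List[float]], int]:
    """Batch-outer loop: every batch streams every layer's weights again."""
    loads = 0
    outputs = []
    for batch in batches:
        acts = list(batch)
        for layer in layers:
            loads += 1  # one weight load per (batch, layer) pair
            acts = [layer(x) for x in acts]
        outputs.append(acts)
    return outputs, loads


def inverted_schedule(layers: List[Layer], batches: List[List[float]]) -> Tuple[List[List[float]], int]:
    """Layer-outer loop: load each layer once, push every batch through it."""
    loads = 0
    spill = [list(batch) for batch in batches]  # stand-in for the on-disk buffer
    for layer in layers:
        loads += 1  # one weight load per layer
        spill = [[layer(x) for x in acts] for acts in spill]
    return spill, loads


layers = [lambda x, k=k: 0.5 * x + k for k in range(80)]  # toy "80-layer model"
batches = [[float(i + j) for j in range(8)] for i in range(0, 1000, 8)]  # 125 batches of 8

naive_out, naive_loads = naive_schedule(layers, batches)
inv_out, inv_loads = inverted_schedule(layers, batches)
print(naive_out == inv_out, naive_loads, inv_loads)  # True 10000 80
```

Each sample passes through the layers in the same order under both schedules, so the outputs match exactly; only the loop nesting, and therefore the weight traffic, changes.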
Aspirational performance
Targets, not measurements. These are the numbers fpwap is being built to hit; each row unlocks only after its milestone lands (70B gates on the bit-perfect test; 405B gates on the mmap-from-HF-cache path). Rows are replaced by measured benchmarks as those come in.
Reference machine
| Component | Spec |
|---|---|
| GPU | NVIDIA RTX 5090, 32 GB VRAM |
| CPU | Modern desktop-class, 16+ cores |
| RAM | 128 GB DDR5 |
| Storage | NVMe SSD (Gen 4+), ≥ 1 TB free |
| Interconnect | PCIe 5.0 x16 |
| Network | None — fully local, no cloud |
Dataset-scale activation extraction (10,000 prompts × 256 tokens = 2.56M tokens)
Residual stream (residual_post) captured at every layer, pooled to last token, persisted to disk. RawActivations(layers="all").
| Model | Weights (bf16) | Loading strategy | Wall-clock target | Throughput target | vs. naive accelerate.cpu_offload |
|---|---|---|---|---|---|
| Llama-3.1-8B | 16 GB | cpu_offload | ≤ 8 min | ≥ 5,000 tok/s | ≥ 4× faster |
| Llama-3.1-70B | 140 GB | disk_offload | ≤ 45 min | ≥ 950 tok/s | ≥ 4× faster (naive ≈ 3 h) |
| Llama-3.1-405B | 810 GB | mmap_from_cache | ≤ 4 h | ≥ 180 tok/s | naive infeasible (OOM in RAM) |
Throughput is end-to-end tokens per second — total tokens processed (samples × seq_len) divided by wall-clock from fpwap(...).run() entry to return, including weight I/O, forward, callbacks, and buffer write.
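Because throughput is defined as total tokens over end-to-end wall-clock, the wall-clock and throughput columns above are the same constraint stated twice. A quick check (the helper name is ours, not fpwap's; fpwap reports measured throughput via ProfileReport.throughput_tok_per_s()):

```python
def end_to_end_tok_per_s(n_samples: int, seq_len: int, wall_clock_s: float) -> float:
    """Total tokens (samples x seq_len) divided by wall-clock from run() entry to return."""
    return n_samples * seq_len / wall_clock_s


# Dataset-scale targets: 10,000 prompts x 256 tokens = 2.56M tokens.
for model, minutes in [("Llama-3.1-8B", 8), ("Llama-3.1-70B", 45), ("Llama-3.1-405B", 240)]:
    rate = end_to_end_tok_per_s(10_000, 256, minutes * 60)
    print(f"{model}: {minutes} min -> {rate:,.0f} tok/s")
# Llama-3.1-8B: 8 min -> 5,333 tok/s
# Llama-3.1-70B: 45 min -> 948 tok/s
# Llama-3.1-405B: 240 min -> 178 tok/s
```

The 70B and 405B wall-clock targets land a hair under their throughput targets, so the two columns agree only up to rounding.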
Single-pass cost per layer (Llama-3.1-70B, 1.75 GB weights per layer)
The inner loop that fpwap is optimizing. On the reference machine, per layer, per full sweep of 10k × 256-token samples:
| Phase | Budget | Notes |
|---|---|---|
| Weight load | ≤ 1.0 s | NVMe → CPU → GPU, disk_offload path; once per layer, not once per batch |
| Forward | ≤ 15 s | 10k samples, bf16, batched at engine's discretion |
| Callback | ≤ 1.0 s | Aggregate across all registered callbacks for this layer |
| Buffer write | ≤ 1.0 s | Pooled activations to memmap; raw [N, S, H] budget is higher |
| Per-layer total | ≤ 18 s | × 80 layers ≈ 24 min (leaves headroom vs. 45 min end-to-end target) |
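The per-layer budget sums directly to the 24-minute figure. The one derived number below, the weight-read bandwidth implied by loading 1.75 GB in ≤ 1.0 s, is our inference from the table rather than a stated spec.

```python
LAYER_WEIGHTS_GB = 1.75  # Llama-3.1-70B bf16 weights per layer (from the heading above)
N_LAYERS = 80

budget_s = {"weight_load": 1.0, "forward": 15.0, "callback": 1.0, "buffer_write": 1.0}

per_layer_s = sum(budget_s.values())                          # seconds per layer
total_min = per_layer_s * N_LAYERS / 60                       # minutes per full sweep
load_gb_per_s = LAYER_WEIGHTS_GB / budget_s["weight_load"]    # sustained read rate needed

print(per_layer_s, total_min, load_gb_per_s)  # 18.0 24.0 1.75
```

That leaves about 21 minutes of headroom against the 45-minute end-to-end target.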
Overhead budgets
| Surface | Budget | Why |
|---|---|---|
| Profile + progress, combined | < 1% wall-clock | Has to stay on by default — see the Observability section |
| verify=True (vs. naive cpu_offload at every layer) | 2–3× slower | Correctness debugging only; not for production runs |
| Preflight | < 5 s | Rejects infeasible configurations before GPU contact |
The API
One verb. One callback class. One result.
```python
from fpwap import Sweep
from fpwap.callbacks.common import RawActivations, IncrementalPCA, DiffOfMeans

run = Sweep(
    model="meta-llama/Llama-3.1-70B",
    dataset=my_dataset,  # iterable of {"input_ids": ..., "label": ...}
    seq_len=256,
    callbacks=[
        RawActivations(layers=[40, 45, 50]),  # pooled by default
        IncrementalPCA(layers="all", n_components=64),
        DiffOfMeans(layers="all", label_fn=lambda s: s["label"]),
    ],
)

plan = run.preflight()
print(plan.summary())  # check feasibility before GPU contact

result = run.run()
acts = result.activations(layer=45, hook="residual_post")  # [N, H]
basis = result.artifact("pca_basis", layer=45)
```
That is the entire user-facing surface for read-only workflows. No backend objects to construct. No batch_size knob to foot-gun. No loader / accumulator triple to wire up. Construction is cheap; .preflight() inspects the plan and rejects infeasible configurations with actionable messages; .run() executes.
Layer indexing
Hook names follow the HF hidden_states convention:
| Hook | Equals |
|---|---|
| residual_pre at layer L | hidden_states[L] (input to block L) |
| residual_post at layer L | hidden_states[L+1] (output of block L) |
| attn_out at layer L | attention sub-layer output at block L |
| mlp_out at layer L | MLP sub-layer output at block L |
No off-by-one translation at the call site.
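The table can be read as a small lookup. The helper below is hypothetical (not part of fpwap); it maps a hook and layer to an index into the hidden_states tuple that HF models return with output_hidden_states=True, where index 0 is the embedding output.

```python
from typing import Optional


def hf_hidden_states_index(hook: str, layer: int) -> Optional[int]:
    """Map a (hook, layer) pair to an index into HF `hidden_states`.

    Returns None for sub-layer hooks, which have no `hidden_states` entry.
    """
    if hook == "residual_pre":
        return layer          # input to block L
    if hook == "residual_post":
        return layer + 1      # output of block L
    if hook in ("attn_out", "mlp_out"):
        return None           # sub-layer outputs live inside block L
    raise ValueError(f"unknown hook: {hook!r}")


# residual_post at layer L and residual_pre at layer L+1 name the same tensor.
print(hf_hidden_states_index("residual_post", 0), hf_hidden_states_index("residual_pre", 1))  # 1 1
```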
Writing your own callback
Subclass Callback. Declare which layers and hooks you want; implement on_batch. Return an Emit to persist a tensor, a WriteBack to modify the residual before the next layer, or None to no-op.
```python
from fpwap import Callback, Emit

class LastTokenLogNorm(Callback):
    target_layers = [32]
    target_hooks = ("residual_post",)
    phase = "read"

    def on_batch(self, layer_idx, hook, acts, sample_ids):
        # acts is [batch, seq, hidden]; persist the log-norm of the last position
        return Emit(acts[:, -1, :].norm(dim=-1).log())
```
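To make the three return values concrete, here is a toy dispatcher. It is a hypothetical sketch, not fpwap's runner: ToyEmit and ToyWriteBack are stand-ins for Emit and WriteBack, callbacks are plain functions, and the assumption that callbacks run in list order and see earlier write-backs is ours (the real engine may, for example, order callbacks by phase).

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union


@dataclass
class ToyEmit:
    """Hypothetical stand-in for fpwap's Emit: a value to persist."""
    value: Any


@dataclass
class ToyWriteBack:
    """Hypothetical stand-in for fpwap's WriteBack: a replacement residual."""
    value: Any


ToyResult = Optional[Union[ToyEmit, ToyWriteBack]]
ToyCallback = Callable[[int, str, Any, List[int]], ToyResult]


def dispatch(acts: Any, callbacks: List[ToyCallback], layer_idx: int, hook: str,
             sample_ids: List[int], emitted: Dict[Tuple[int, str], List[Any]]) -> Any:
    """Route callback results; return the residual the next layer should see."""
    for cb in callbacks:
        out = cb(layer_idx, hook, acts, sample_ids)
        if out is None:
            continue  # no-op
        if isinstance(out, ToyEmit):
            emitted.setdefault((layer_idx, hook), []).append(out.value)  # persist
        elif isinstance(out, ToyWriteBack):
            acts = out.value  # later callbacks and the next layer see the new residual
        else:
            raise TypeError(f"unexpected callback result: {out!r}")
    return acts


emitted: Dict[Tuple[int, str], List[Any]] = {}
callbacks = [
    lambda layer, hook, a, ids: ToyEmit(sum(a)),                   # read
    lambda layer, hook, a, ids: ToyWriteBack([2 * x for x in a]),  # write
    lambda layer, hook, a, ids: None,                              # no-op
]
print(dispatch([1, 2], callbacks, 3, "residual_post", [0, 1], emitted), emitted)
# [2, 4] {(3, 'residual_post'): [3]}
```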
Write-backs and multi-pass workflows
The same entry point handles steering. A callback with phase = "write" modifies the residual stream between layers; artifacts from one run feed the next.
```python
from fpwap import Sweep
from fpwap.callbacks.common import SteerInBasis

# Pass 2: steer in the basis fit during pass 1 (`result` is pass 1's Result)
steer = Sweep(
    model="meta-llama/Llama-3.1-70B",
    dataset=my_dataset,
    seq_len=256,
    callbacks=[
        SteerInBasis(
            basis_artifact=result.artifact("pca_basis", layer=45),
            direction_idx=0,
            alpha=2.0,
            layers=[45],
        ),
    ],
)
steered = steer.run()
```
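The arithmetic of an additive intervention in a basis is simple: add alpha times the chosen basis direction to the residual. The sketch below is not SteerInBasis itself; it assumes the basis is stored row-wise (one component per row) and operates on a single hidden vector, whereas a real callback would apply the same update at every batch and sequence position.

```python
from typing import List, Sequence


def steer_in_basis(resid: Sequence[float], basis: Sequence[Sequence[float]],
                   direction_idx: int, alpha: float) -> List[float]:
    """Additive intervention: resid + alpha * basis[direction_idx].

    Assumes `basis` is stored row-wise (one component per row).
    """
    direction = basis[direction_idx]
    if len(direction) != len(resid):
        raise ValueError("basis direction and residual must have the same hidden size")
    return [r + alpha * d for r, d in zip(resid, direction)]


basis = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # two unit-norm components, H = 3
print(steer_in_basis([1.0, 0.0, 0.0], basis, direction_idx=0, alpha=2.0))  # [1.0, 2.0, 0.0]
```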
Observability
Performance is the product, so every run is profiled by default with a measurement overhead small enough (target: under 1% wall-clock) that you never have to opt in. When a run is slower than you want, the answer is already in result.profile — no re-running with profile=True.
```python
result = run.run()
result.profile.summary()        # human-readable breakdown per layer
result.profile.by_phase()       # load / forward / callback / write
result.profile.slowest_layer()  # where the time went
result.profile.bytes_moved()    # weight I/O, buffer I/O
```
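Always-on profiling stays cheap because recording a phase costs only two clock reads and a dictionary update. The sketch below is hypothetical (PhaseTimer is not fpwap's ProfileReport) and borrows the by_phase / slowest_layer names only for readability. One caveat on GPUs: host-side timers only see kernel launches unless the CUDA stream is synchronized or CUDA events are used.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import DefaultDict, Dict, Iterator, Tuple


class PhaseTimer:
    """Accumulate wall-clock seconds per (layer, phase) with one perf_counter pair per span."""

    def __init__(self) -> None:
        self.totals: DefaultDict[Tuple[int, str], float] = defaultdict(float)

    @contextmanager
    def phase(self, layer: int, name: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[(layer, name)] += time.perf_counter() - start

    def by_phase(self) -> Dict[str, float]:
        out: DefaultDict[str, float] = defaultdict(float)
        for (_, name), secs in self.totals.items():
            out[name] += secs
        return dict(out)

    def slowest_layer(self) -> int:
        per_layer: DefaultDict[int, float] = defaultdict(float)
        for (layer, _), secs in self.totals.items():
            per_layer[layer] += secs
        return max(per_layer, key=per_layer.get)


timer = PhaseTimer()
with timer.phase(0, "forward"):
    time.sleep(0.01)
with timer.phase(1, "forward"):
    time.sleep(0.05)
print(timer.slowest_layer(), sorted(timer.by_phase()))  # 1 ['forward']
```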
Interactive progress is on by default — a tqdm-style bar across layers × batches, because a run on the workstation under your desk should not sit silent for 40 minutes. Disable with progress=False; pass a callable (progress=my_reporter) to stream events into wandb, rich, or any other backend.
Known cliff: CUDA allocator fragmentation on K-sweep configs
If a K-sweep run on tight VRAM (K-packed sweeps, chunk_size=1, large per-K residual buffer) shows episodic multi-minute pauses every few sweeps with the process stuck in D-state but NVMe mostly idle in iostat, the cause is almost certainly CUDA caching-allocator fragmentation, not host I/O. Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before launch — segments grow contiguously on demand and the cliff disappears (verified on a 70B / K=30 / 14k repro: 2:20 → 0:55 wall, no other change). See #72 for the diagnostic walk.
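The simplest fix is to set the variable in the launching shell (PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python your_script.py). From Python, a sketch like the one below works as long as it runs before PyTorch initializes CUDA; doing it before import torch is the safe choice. The helper name is ours; it keeps any other comma-separated allocator options already present.

```python
import os
from typing import MutableMapping


def enable_expandable_segments(env: MutableMapping[str, str] = os.environ) -> str:
    """Ensure PYTORCH_CUDA_ALLOC_CONF contains expandable_segments:True.

    Other comma-separated allocator options already present are preserved.
    Must run before PyTorch initializes CUDA; safest is before `import torch`.
    """
    current = env.get("PYTORCH_CUDA_ALLOC_CONF", "")
    opts = [o.strip() for o in current.split(",") if o.strip()]
    opts = [o for o in opts if not o.startswith("expandable_segments:")]
    opts.append("expandable_segments:True")
    env["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(opts)
    return env["PYTORCH_CUDA_ALLOC_CONF"]


enable_expandable_segments()
# import torch  # import only after the variable is set
```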
Reference callbacks
Four callbacks ship with the library as examples and integration tests:
- RawActivations: persist per-sample activations, pooled (last_token_only=True) by default to avoid an [N, S, H] memory landmine.
- IncrementalPCA: fit a PCA basis over the entire dataset in a single pass.
- DiffOfMeans: compute per-class activation means for binary-labeled data.
- SteerInBasis: additive intervention in a pre-computed basis; phase = "write".
Anything beyond these four is a consumer's problem.
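DiffOfMeans fits the single-pass model because a class mean needs only a running sum and a count, both updated batch by batch. The sketch below illustrates that idea on plain lists; it is not fpwap's DiffOfMeans, and the class name is ours.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, Hashable, Iterable, List, Sequence


class RunningClassMeans:
    """Single-pass per-class means: a running sum and a count per label."""

    def __init__(self) -> None:
        self.sums: Dict[Hashable, List[float]] = {}
        self.counts: DefaultDict[Hashable, int] = defaultdict(int)

    def update(self, vectors: Iterable[Sequence[float]], labels: Iterable[Hashable]) -> None:
        # Called once per batch; earlier batches never need to be revisited.
        for vec, label in zip(vectors, labels):
            acc = self.sums.setdefault(label, [0.0] * len(vec))
            for i, x in enumerate(vec):
                acc[i] += x
            self.counts[label] += 1

    def mean(self, label: Hashable) -> List[float]:
        return [s / self.counts[label] for s in self.sums[label]]

    def diff(self, pos: Hashable, neg: Hashable) -> List[float]:
        """Difference of class means (pos minus neg)."""
        return [p - n for p, n in zip(self.mean(pos), self.mean(neg))]


means = RunningClassMeans()
means.update([[1.0, 2.0], [3.0, 4.0]], [1, 1])  # batch 1
means.update([[0.0, 0.0]], [0])                 # batch 2
print(means.diff(1, 0))  # [2.0, 3.0]
```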
Integrating fpwap into a research codebase
The recommended shape is a single classmethod on your codebase's activation-source type, inserted above any per-batch sharding your framework does:
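```python
from fpwap import Sweep
from fpwap.callbacks.common import RawActivations

class Activations:
    @classmethod
    def from_fpwap(cls, model_id, prompts, layers, pool="last_token"):
        # _as_dataset and from_result are your codebase's own helpers
        run = Sweep(
            model=model_id,
            dataset=_as_dataset(prompts),
            seq_len=...,
            callbacks=[
                RawActivations(
                    layers=layers,
                    last_token_only=(pool == "last_token"),
                ),
            ],
        )
        return cls.from_result(run.run())
```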
Branch use_fpwap at your dispatch layer — the same place you'd branch between from_model, from_goodfire, etc. — not inside a per-batch loop. fpwap's value (amortizing layer loads across the whole dataset) only materializes if it sees the dataset; if your framework shards externally and calls an extractor per shard, lift the dispatch up one level before integrating.
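The placement rule can be sketched with stub backends. Everything here is hypothetical (extract, fpwap_extract, per_shard_extract are illustrative names): the fpwap branch is taken once with the full prompt list, while other backends keep the framework's per-shard loop. If fpwap were instead called inside that loop, every shard would reload every layer and weight I/O would scale with shards × layers again.

```python
from typing import Callable, List, Sequence

Extractor = Callable[[Sequence[str]], List]


def extract(prompts: Sequence[str], use_fpwap: bool, shard_size: int,
            fpwap_extract: Extractor, per_shard_extract: Extractor) -> List:
    """Branch once, above sharding, so fpwap sees the whole dataset."""
    if use_fpwap:
        # One call: each layer is loaded once for the entire dataset.
        return fpwap_extract(prompts)
    out: List = []
    for start in range(0, len(prompts), shard_size):
        # Other backends keep the framework's usual per-shard loop.
        out.extend(per_shard_extract(prompts[start:start + shard_size]))
    return out


calls: List[int] = []

def stub(batch: Sequence[str]) -> List:
    calls.append(len(batch))  # record how much of the dataset each call saw
    return list(batch)

prompts = [str(i) for i in range(10)]
extract(prompts, True, 3, stub, stub)
print(calls)  # [10]
```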
Scope
fpwap is a plumbing layer. It produces activations and accepts transforms. It does not know what a probe is. Linear probe fitting, SAE training, attribution analysis, and any other statistical modeling of activations belong in consumer libraries. If it requires knowing what a probe is, it's out of scope.
Related work
The loop inversion at the heart of fpwap — load each layer once, stream the dataset through it — was explored independently by FlexGen (Sheng et al., ICML 2023) for high-throughput generative inference on a single GPU. FlexGen calls this a "zig-zag block schedule" and proves it is within 2× of I/O-optimal (Theorem 4.1) — a result that applies directly to fpwap's loop, since our schedule is the same modulo KV cache. FlexGen solves a harder scheduling problem (KV cache placement across GPU/CPU/disk, multi-step autoregressive decoding, CPU compute delegation) and applies 4-bit group-wise quantization to further compress weights. fpwap targets a narrower regime — forward-pass activation extraction for mechanistic interpretability — where full precision is non-negotiable and generation is not needed, so the implementation is much simpler. The absence of KV cache and autoregressive decoding also means fpwap's cost model has fewer free variables, making strategy selection tractable without an LP solver.
Status
Llama-3.1-405B on a single RTX 5090 (32 GB VRAM), streaming 803 GB of bf16 weights from NVMe — 45.7 tok/s in under 12 minutes. 70B at 10,000 prompts × 128 tokens hits 1,221 tok/s. That's the regime fpwap exists for: the model doesn't fit in VRAM (or even RAM), the dataset is thousands of prompts, and no quantization is involved. Measured on the reference machine (RTX 5090, 128 GB DDR5, PCIe 5.0 NVMe):
| Model | Path | Samples × seq_len | Throughput (bf16) | SPEC target (tok/s) |
|---|---|---|---|---|
| Llama-3.1-405B-Instruct | streaming, prefetch | 256 × 128 | 45.7 tok/s | ≥ 180 |
| Llama-3.3-70B-Instruct | streaming, prefetch | 10,000 × 128 | 1,221 tok/s | ≥ 950 |
| Llama-3.3-70B-Instruct | streaming | 1,024 × 128 | 1,026 tok/s | ≥ 950 |
| Llama-3.1-8B-Instruct | streaming | 1,024 × 128 | 10,442 tok/s | ≥ 5,000 |
| Llama-3.1-8B-Instruct | preloaded | 256 × 128 | 11,894 tok/s | ≥ 5,000 |
The 405B number is end-to-end across 126 layers streaming 803 GB of weights from NVMe SSD at 1.12 GB/s sustained, with prefetch fully hiding disk reads behind compute (0.000s load per layer at steady state). The 70B hero number is end-to-end across 80 layers with a pinned-CPU residual buffer (21 GB), async D2H, and a worker-thread weight prefetch that overlaps layer L+1's safetensors read with layer L's compute.
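The prefetch pattern is a one-slot pipeline: while layer L computes on the caller's thread, a single worker thread reads layer L+1. The sketch below is not fpwap's implementation; load and compute are hypothetical placeholders for the safetensors read plus H2D copy and the batched forward. A single Python worker thread can overlap these because file reads and most PyTorch ops release the GIL. At most two layers are resident at once: the one computing and the one being prefetched.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List


def run_layers_with_prefetch(n_layers: int, load: Callable[[int], Any],
                             compute: Callable[[int, Any], None], log: List[str]) -> None:
    """Overlap load(L+1) on a worker thread with compute(L) on the caller's thread."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(load, 0)
        for layer in range(n_layers):
            weights = pending.result()  # blocks only if the prefetch fell behind
            if layer + 1 < n_layers:
                pending = pool.submit(load, layer + 1)
                log.append(f"prefetch {layer + 1}")
            log.append(f"compute {layer}")
            compute(layer, weights)


log: List[str] = []
computed: List[Any] = []
run_layers_with_prefetch(3, load=lambda i: f"weights[{i}]",
                         compute=lambda i, w: computed.append((i, w)), log=log)
print(log)  # ['prefetch 1', 'compute 0', 'prefetch 2', 'compute 1', 'compute 2']
```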
Baseline sanity: an 8B streaming-vs-naive head-to-head shows a 7.25×
speedup at 1024 × 128 (SPEC §17 ratio target ≥ 4×). The naive baseline
is accelerate.cpu_offload at 1,440 tok/s, reproducible via
scripts/benchmark.py --mode naive. 70B can't ratio-test on this machine
(141 GB bf16 > 128 GB RAM for cpu_offload); the 70B claim is absolute
throughput.
Correctness: tests/gpu/test_real_llama_bit_exact.py runs Llama-3.2-1B in bf16
on CUDA and compares every layer's residual_post against a naive HF forward —
bit-exact (torch.equal) at every real token position. When microbatch_size
equals the naive batch size, bf16 is deterministic; at different microbatch
sizes, outputs diverge by LSB accumulation noise (see the memory note on
bf16_microbatch_determinism).
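With PyTorch, a masked bit-exact check is typically torch.equal(x[mask.bool()], y[mask.bool()]). The pure-Python sketch below shows the same logic and is not the code in tests/gpu/test_real_llama_bit_exact.py. The last lines illustrate, with ordinary floats, why microbatch size matters: floating-point addition is not associative, and GPU kernels can choose a different reduction order for a different batch shape.

```python
from typing import Sequence

Acts = Sequence[Sequence[Sequence[float]]]  # [batch][position][hidden]


def bit_exact_at_real_tokens(a: Acts, b: Acts, mask: Sequence[Sequence[int]]) -> bool:
    """Exact equality of two activation tensors at positions where mask == 1."""
    if len(a) != len(b) or len(a) != len(mask):
        raise ValueError("batch sizes differ")
    for a_seq, b_seq, m_seq in zip(a, b, mask):
        for a_vec, b_vec, m in zip(a_seq, b_seq, m_seq):
            if m and list(a_vec) != list(b_vec):
                return False  # any difference, even in the last bit, fails
    return True


# Padding positions are ignored; real tokens must match exactly.
a = [[[1.0, 2.0], [9.0, 9.0]]]
b = [[[1.0, 2.0], [0.0, 0.0]]]
print(bit_exact_at_real_tokens(a, b, [[1, 0]]))  # True
print(bit_exact_at_real_tokens(a, b, [[1, 1]]))  # False

# A different summation order changes the low-order bits.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```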
What's wired:
- pre-loaded and streaming model paths
- Sweep + Callback + Result API
- padded-batch + attention-mask propagation
- RoPE-aware Llama plumbing
- GPT-2 plumbing
- all four hooks (residual_pre, attn_out, mlp_out, residual_post), with a fast-path block forward when no sub-layer hook is wanted, and WriteBack at every hook (sub-layer WriteBack is threaded through the block mid-forward so the modified tensor actually affects downstream compute)
- all four reference callbacks shipped (RawActivations, IncrementalPCA, DiffOfMeans, SteerInBasis)
- result.activations(...)
- tqdm progress plus callable progress=reporter emitting ProgressEvents for wandb/rich sinks
- pinned-CPU buffer_device="cpu" with async D2H copy (so oversized residual buffers don't block compute)
- worker-thread concurrent weight prefetch on the streaming path (layer L+1's safetensors read + H2D overlap with layer L's compute)
- MemmapBackend for disk-backed emits
- ProfileReport.throughput_tok_per_s() / weight_bandwidth_gb_per_s()
- verify=True fail-fast against a naive-forward baseline (pre-loaded models)
- per-layer on_layer_end artifacts collected into result.artifacts
Model families covered by the structural matcher: Llama, Mistral, Qwen2,
Gemma, DeepSeek-V2, and any future HF causal LM exposing the same
model.{embed_tokens, layers, rotary_emb} layout. GPT-2 covered by its
own plumbing.
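A structural matcher of this kind can be as simple as duck-typing on attribute names. The sketch below is hypothetical and checks only the three attributes named above; fpwap's actual matcher may check more. HF causal LMs such as LlamaForCausalLM wrap the decoder in a .model attribute, while GPT-2 uses .transformer with different submodule names, which is why it needs its own plumbing.

```python
from types import SimpleNamespace

REQUIRED_ATTRS = ("embed_tokens", "layers", "rotary_emb")


def matches_llama_layout(causal_lm) -> bool:
    """True if causal_lm.model exposes embed_tokens, layers, and rotary_emb."""
    inner = getattr(causal_lm, "model", None)
    return inner is not None and all(hasattr(inner, name) for name in REQUIRED_ATTRS)


llama_like = SimpleNamespace(model=SimpleNamespace(embed_tokens=None, layers=[], rotary_emb=None))
gpt2_like = SimpleNamespace(transformer=SimpleNamespace(wte=None, h=[]))
print(matches_llama_layout(llama_like), matches_llama_layout(gpt2_like))  # True False
```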
What's not yet: checkpoint/resume, NVMe-backed ResidualBuffer,
verify=True on the streaming path (pre-loaded only).
See SPEC.md for the full design.
Download files
- Source distribution: fpwap-0.1.1.tar.gz
- Built distribution: fpwap-0.1.1-py3-none-any.whl
File details
Details for the file fpwap-0.1.1.tar.gz.
File metadata
- Download URL: fpwap-0.1.1.tar.gz
- Upload date:
- Size: 215.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f9f1294bfaaab9e49797e62fbd3be5c7dae086bb193c37f4f567384f2549671c |
| MD5 | b8fd593bd04c94c3ae3586890f26ff4f |
| BLAKE2b-256 | fd34786686b2ac6a6ee629f86971139988796f706f8f0e6fc83a05e43696145c |
Provenance
The following attestation bundles were made for fpwap-0.1.1.tar.gz:
Publisher: release.yml on AlliedToasters/fpwap
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: fpwap-0.1.1.tar.gz
- Subject digest: f9f1294bfaaab9e49797e62fbd3be5c7dae086bb193c37f4f567384f2549671c
- Sigstore transparency entry: 1780210916
- Sigstore integration time:
- Permalink: AlliedToasters/fpwap@4f859ed891cc70b568e36e08c06d41fb10f29c72
- Branch / Tag: refs/heads/main
- Owner: https://github.com/AlliedToasters
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@4f859ed891cc70b568e36e08c06d41fb10f29c72
- Trigger Event: push
File details
Details for the file fpwap-0.1.1-py3-none-any.whl.
File metadata
- Download URL: fpwap-0.1.1-py3-none-any.whl
- Upload date:
- Size: 57.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d052dd627368542c94bf3118e8ae10db416a9bac1c3e48037879f62a56c0ade0 |
| MD5 | c4fcc9edc0258316431fdcb0b70fe0d5 |
| BLAKE2b-256 | 002f788bf9002887c1d4a78173d3ed407f72412189bf8a39b9d5e176715f99b4 |
Provenance
The following attestation bundles were made for fpwap-0.1.1-py3-none-any.whl:
Publisher: release.yml on AlliedToasters/fpwap
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: fpwap-0.1.1-py3-none-any.whl
- Subject digest: d052dd627368542c94bf3118e8ae10db416a9bac1c3e48037879f62a56c0ade0
- Sigstore transparency entry: 1780211105
- Sigstore integration time:
- Permalink: AlliedToasters/fpwap@4f859ed891cc70b568e36e08c06d41fb10f29c72
- Branch / Tag: refs/heads/main
- Owner: https://github.com/AlliedToasters
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@4f859ed891cc70b568e36e08c06d41fb10f29c72
- Trigger Event: push