
Forward Pass Weight Amortization Protocol — invert the inference loop for large transformer models.


fpwap — Forward Pass Weight Amortization Protocol

A single-purpose library for running activation extraction over large transformer models whose weights don't fit in your GPU, across datasets of thousands of prompts, on consumer hardware, at full (unquantized) precision.

The regime

You're a mech-interp researcher. Your model is bigger than your VRAM. Your dataset is thousands of prompts. Adjacent tools each fail, either by changing what you're studying or by making the study impractical:

  • Quantization (bitsandbytes, GPTQ) changes the activations you're trying to read.
  • Inference servers (vLLM, TGI) optimize next-token throughput, not residual-stream extraction.
  • accelerate.cpu_offload streams every layer's weights once per batch. For 10k prompts on a 70B model (140 GB of bf16 weights), that is tens to hundreds of TB of weight I/O depending on batch size, and hours of wall-clock per dataset pass.
  • Cloud GPUs break your interactive iteration loop and cost hundreds per experiment.

fpwap inverts the inference loop: load each layer once, stream the whole dataset through it, spill intermediates to disk, move on. Total weight I/O drops from O(N_batches × N_layers) to O(N_layers). A 10k-sample Llama-3.1-70B extraction on a 32 GB consumer GPU reads its weights once, the same weight I/O the naive approach spends on a single batch. Same weights, no quantization, no cloud.
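A toy model of the two loop orders makes the O(N_batches × N_layers) vs. O(N_layers) difference concrete. This is an illustrative sketch, not fpwap's implementation: the hypothetical `naive_schedule` and `inverted_schedule` functions treat each "layer" as a plain Python function, count one weight load each time a layer is touched, and keep the spilled intermediates in an in-memory list rather than on disk. The reordering is safe because, in a forward pass, a sample's activations at layer L+1 depend only on that same sample's activations at layer L.

```python
from typing import Callable, List, Tuple

Layer = Callable[[float], float]  # toy stand-in for a transformer block


def naive_schedule(layers: List[Layer], batches: List[List[float]]) -> Tuple[List[List[float]], int]:
    """Batch-outer loop: every batch re-streams every layer (cpu_offload style)."""
    loads = 0
    outputs = []
    for batch in batches:
        acts = list(batch)
        for layer in layers:
            loads += 1  # this layer's weights are fetched again for this batch
            acts = [layer(a) for a in acts]
        outputs.append(acts)
    return outputs, loads


def inverted_schedule(layers: List[Layer], batches: List[List[float]]) -> Tuple[List[List[float]], int]:
    """Layer-outer loop: load each layer once, stream every batch through it."""
    loads = 0
    # Intermediates for the whole dataset (fpwap spills these to a buffer/disk).
    spill = [list(batch) for batch in batches]
    for layer in layers:
        loads += 1  # one load per layer, amortized across all batches
        spill = [[layer(a) for a in acts] for acts in spill]
    return spill, loads


layers = [lambda x, k=k: 2 * x + k for k in range(4)]
batches = [[1.0, 2.0], [3.0], [4.0, 5.0]]
naive_out, naive_loads = naive_schedule(layers, batches)
inv_out, inv_loads = inverted_schedule(layers, batches)
```

With 4 layers and 3 batches, the naive order performs 12 loads and the inverted order 4, while producing identical outputs.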

Aspirational performance

Targets, not measurements. These are the numbers fpwap is being built to; each row unlocks only after its milestone lands (70B gates on the bit-perfect test; 405B gates on the mmap-from-HF-cache path). Replaced by measured benchmarks as they come in.

Reference machine

| Component | Spec |
|---|---|
| GPU | NVIDIA RTX 5090, 32 GB VRAM |
| CPU | Modern desktop-class, 16+ cores |
| RAM | 128 GB DDR5 |
| Storage | NVMe SSD (Gen 4+), ≥ 1 TB free |
| Interconnect | PCIe 5.0 x16 |
| Network | None (fully local, no cloud) |

Dataset-scale activation extraction (10,000 prompts × 256 tokens = 2.56M tokens)

Residual stream (residual_post) captured at every layer, pooled to last token, persisted to disk. RawActivations(layers="all").

| Model | Weights (bf16) | Loading strategy | Wall-clock target | Throughput target | vs. naive accelerate.cpu_offload |
|---|---|---|---|---|---|
| Llama-3.1-8B | 16 GB | cpu_offload | ≤ 8 min | ≥ 5,000 tok/s | ≥ 4× faster |
| Llama-3.1-70B | 140 GB | disk_offload | ≤ 45 min | ≥ 950 tok/s | ≥ 4× faster (naive ≈ 3 h) |
| Llama-3.1-405B | 810 GB | mmap_from_cache | ≤ 4 h | ≥ 180 tok/s | naive infeasible (OOM in RAM) |

Throughput is end-to-end tokens per second: total tokens processed (samples × seq_len) divided by wall-clock from Sweep(...).run() entry to return, including weight I/O, forward, callbacks, and buffer write.
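Because throughput is defined over end-to-end wall-clock, each wall-clock target in the table implies a throughput. A quick check for the 10,000 × 256 = 2.56M-token workload (plain Python; `tok_per_s` is a hypothetical helper, not part of fpwap):

```python
def tok_per_s(samples: int, seq_len: int, wall_s: float) -> float:
    """End-to-end throughput: total tokens divided by wall-clock seconds."""
    return samples * seq_len / wall_s


SAMPLES, SEQ_LEN = 10_000, 256
WALL_TARGET_MIN = {"Llama-3.1-8B": 8, "Llama-3.1-70B": 45, "Llama-3.1-405B": 4 * 60}

implied = {m: tok_per_s(SAMPLES, SEQ_LEN, mins * 60) for m, mins in WALL_TARGET_MIN.items()}
for model, tps in implied.items():
    print(f"{model}: {tps:,.0f} tok/s")
# Llama-3.1-8B: 5,333 tok/s
# Llama-3.1-70B: 948 tok/s
# Llama-3.1-405B: 178 tok/s
```

The 45 min target works out to about 948 tok/s, the stated ≥ 950 tok/s rounded; the 405B pair is similarly close (4 h is about 178 tok/s against the ≥ 180 tok/s target).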

Single-pass cost per layer (Llama-3.1-70B, 1.75 GB weights per layer)

The inner loop that fpwap is optimizing. On the reference machine, per layer, per full sweep of 10k × 256-token samples:

| Phase | Budget | Notes |
|---|---|---|
| Weight load | ≤ 1.0 s | NVMe → CPU → GPU, disk_offload path; once per layer, not once per batch |
| Forward | ≤ 15 s | 10k samples, bf16, batched at the engine's discretion |
| Callback | ≤ 1.0 s | Aggregate across all registered callbacks for this layer |
| Buffer write | ≤ 1.0 s | Pooled activations to memmap; the raw [N, S, H] budget is higher |
| Per-layer total | ≤ 18 s | × 80 layers ≈ 24 min (leaves headroom vs. the 45 min end-to-end target) |
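The per-layer budget adds up as the table says, and the same arithmetic shows why the Buffer write budget is stated for pooled activations. The sketch below assumes Llama-3.1-70B's hidden size of 8192 and 2-byte bf16 elements; it is back-of-envelope arithmetic, not a measurement.

```python
BUDGET_S = {"weight_load": 1.0, "forward": 15.0, "callback": 1.0, "buffer_write": 1.0}
N_LAYERS = 80

per_layer_s = sum(BUDGET_S.values())        # 18.0 s per layer
sweep_min = per_layer_s * N_LAYERS / 60     # 24.0 min for all 80 layers

# Disk bandwidth implied by loading one layer within the weight-load budget.
LAYER_WEIGHTS_GB = 140 / N_LAYERS           # 1.75 GB per layer
load_bw_gb_s = LAYER_WEIGHTS_GB / BUDGET_S["weight_load"]

# Size of one layer's activations for one hook (assumed H = 8192, bf16).
N, S, H = 10_000, 256, 8192                 # samples, seq_len, hidden size
BF16_BYTES = 2
raw_gb = N * S * H * BF16_BYTES / 1e9       # [N, S, H]: every token position
pooled_gb = N * H * BF16_BYTES / 1e9        # [N, H]: last token only
raw_write_bw_gb_s = raw_gb / BUDGET_S["buffer_write"]
```

A raw [N, S, H] residual for one layer is about 42 GB, 256× the pooled [N, H] array (one row per sample instead of one per token). Writing that within the 1 s budget would take roughly 42 GB/s, far beyond what a single NVMe SSD sustains, which is why the raw budget has to be higher. The 1.0 s weight-load budget likewise implies about 1.75 GB/s from disk.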

Overhead budgets

| Surface | Budget | Why |
|---|---|---|
| Profile + progress, combined | < 1% wall-clock | Has to stay on by default; see the Observability section |
| verify=True (vs. naive cpu_offload at every layer) | 2–3× slower | Correctness debugging only; not for production runs |
| Preflight | < 5 s | Rejects infeasible configurations before GPU contact |

The API

One verb. One callback class. One result.

from fpwap import Sweep
from fpwap.callbacks.common import RawActivations, IncrementalPCA, DiffOfMeans

run = Sweep(
    model="meta-llama/Llama-3.1-70B",
    dataset=my_dataset,                # iterable of {"input_ids": ..., "label": ...}
    seq_len=256,
    callbacks=[
        RawActivations(layers=[40, 45, 50]),               # pooled by default
        IncrementalPCA(layers="all", n_components=64),
        DiffOfMeans(layers="all", label_fn=lambda s: s["label"]),
    ],
)

plan = run.preflight()
print(plan.summary())                   # check feasibility before GPU contact

result = run.run()
acts  = result.activations(layer=45, hook="residual_post")   # [N, H]
basis = result.artifact("pca_basis", layer=45)

That is the entire user-facing surface for read-only workflows. No backend objects to construct. No batch_size knob to foot-gun. No loader / accumulator triple to wire up. Construction is cheap; .preflight() inspects the plan and rejects infeasible configurations with actionable messages; .run() executes.
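The Sweep example above assumes `my_dataset` already exists. A minimal way to build it, assuming prompts are tokenized up front into lists of ids. The `make_dataset` helper and the `tokenize` callable are hypothetical, and whether fpwap expects pre-truncated id lists or tensors is not specified here.

```python
from typing import Callable, Dict, List, Sequence


def make_dataset(
    texts: Sequence[str],
    labels: Sequence[int],
    tokenize: Callable[[str], List[int]],
) -> List[Dict[str, object]]:
    """Build an iterable of {"input_ids": ..., "label": ...} dicts.

    `tokenize` maps a string to a list of token ids; with a Hugging Face
    tokenizer `tok`, that could be `lambda s: tok(s)["input_ids"]`.
    """
    if len(texts) != len(labels):
        raise ValueError(f"got {len(texts)} texts but {len(labels)} labels")
    # A list rather than a one-shot generator, so it can be iterated more than once.
    return [{"input_ids": tokenize(t), "label": y} for t, y in zip(texts, labels)]


# Toy tokenizer for illustration: one "token" per character code point.
toy = make_dataset(["ab", "c"], [1, 0], lambda s: [ord(ch) for ch in s])
```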

Layer indexing

Hook names follow the HF hidden_states convention:

| Hook | Equals |
|---|---|
| residual_pre at layer L | hidden_states[L] (input to block L) |
| residual_post at layer L | hidden_states[L+1] (output of block L) |
| attn_out at layer L | attention sub-layer output at block L |
| mlp_out at layer L | MLP sub-layer output at block L |
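HF's `output_hidden_states` tuple has `num_layers + 1` entries, with index 0 holding the embedding output, so the two residual hooks map onto it by a fixed offset. A hypothetical helper (not part of fpwap) that makes the table executable:

```python
_OFFSET = {"residual_pre": 0, "residual_post": 1}


def hidden_states_index(hook: str, layer: int) -> int:
    """Index into HF `hidden_states` matching an fpwap-style hook at `layer`."""
    if hook not in _OFFSET:
        # attn_out / mlp_out are sub-layer outputs inside a block; HF does not
        # expose them in hidden_states.
        raise ValueError(f"{hook!r} has no hidden_states entry")
    return layer + _OFFSET[hook]
```

residual_post at layer L and residual_pre at layer L+1 land on the same index, which is the no-off-by-one property the table describes.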

No off-by-one translation at the call site.

Writing your own callback

Subclass Callback. Declare which layers and hooks you want; implement on_batch. Return an Emit to persist a tensor, a WriteBack to modify the residual before the next layer, or None to no-op.

from fpwap import Callback, Emit

class LastTokenLogNorm(Callback):
    target_layers = [32]
    target_hooks = ("residual_post",)
    phase = "read"

    def on_batch(self, layer_idx, hook, acts, sample_ids):
        return Emit(acts[:, -1, :].norm(dim=-1).log())

Write-backs and multi-pass workflows

The same entry point handles steering. A callback with phase = "write" modifies the residual stream between layers; artifacts from one run feed the next.

from fpwap.callbacks.common import SteerInBasis

# Pass 2: steer in the basis fit during pass 1
steer = Sweep(
    model="meta-llama/Llama-3.1-70B",
    dataset=my_dataset,
    seq_len=256,
    callbacks=[
        SteerInBasis(
            basis_artifact=result.artifact("pca_basis", layer=45),
            direction_idx=0,
            alpha=2.0,
            layers=[45],
        ),
    ],
)
steered = steer.run()
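SteerInBasis is described as an additive intervention in a pre-computed basis. A minimal NumPy sketch of what such an intervention computes, under stated assumptions: the basis holds one direction per row (an [n_components, H] array), the chosen direction is normalized to unit length, and the shift is applied at every token position. `steer_additive` is hypothetical, and fpwap's actual SteerInBasis may differ on any of these points.

```python
import numpy as np


def steer_additive(resid: np.ndarray, basis: np.ndarray, direction_idx: int, alpha: float) -> np.ndarray:
    """Add alpha * (unit-norm) basis direction to every token position.

    resid: [batch, seq, hidden]; basis: [n_components, hidden], one direction per row.
    """
    direction = basis[direction_idx]
    direction = direction / np.linalg.norm(direction)  # unit length (an assumption)
    return resid + alpha * direction  # broadcasts over batch and seq


rng = np.random.default_rng(0)
resid = rng.normal(size=(2, 3, 4))
basis = rng.normal(size=(2, 4))
steered = steer_additive(resid, basis, direction_idx=0, alpha=2.0)
```

With these assumptions the intervention moves every token's projection onto the chosen direction by exactly `alpha` and leaves the orthogonal complement untouched.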

Observability

Performance is the product, so every run is profiled by default with a measurement overhead small enough (target: under 1% wall-clock) that you never have to opt in. When a run is slower than you want, the answer is already in result.profile — no re-running with profile=True.

result = run.run()

result.profile.summary()            # human-readable breakdown per layer
result.profile.by_phase()           # load / forward / callback / write
result.profile.slowest_layer()      # where the time went
result.profile.bytes_moved()        # weight I/O, buffer I/O

Interactive progress is on by default — a tqdm-style bar across layers × batches, because a run on the workstation under your desk should not sit silent for 40 minutes. Disable with progress=False; pass a callable (progress=my_reporter) to stream events into wandb, rich, or any other backend.
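Any callable works as a progress sink. A sketch of one that rate-limits events before forwarding them to another backend, assuming fpwap calls the reporter with a single ProgressEvent argument. `ThrottledReporter` is hypothetical and treats the event as opaque, since ProgressEvent's fields are not documented here.

```python
import time
from typing import Any, Callable, Optional


class ThrottledReporter:
    """Progress callable that forwards at most one event per `interval_s` seconds."""

    def __init__(
        self,
        sink: Callable[[Any], None],
        interval_s: float = 5.0,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.sink = sink              # e.g. print, or a function that logs to wandb/rich
        self.interval_s = interval_s
        self.clock = clock            # injectable for testing
        self._last: Optional[float] = None

    def __call__(self, event: Any) -> None:
        now = self.clock()
        if self._last is None or now - self._last >= self.interval_s:
            self._last = now
            self.sink(event)          # the event is passed through unchanged
```

It would be passed as `progress=ThrottledReporter(print, interval_s=10.0)`. Throttling can drop the final event, so a real sink may want to flush on completion.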

Known cliff: CUDA allocator fragmentation on K-sweep configs

If a K-sweep run on tight VRAM (K-packed sweeps, chunk_size=1, large per-K residual buffer) shows episodic multi-minute pauses every few sweeps with the process stuck in D-state but NVMe mostly idle in iostat, the cause is almost certainly CUDA caching-allocator fragmentation, not host I/O. Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True before launch — segments grow contiguously on demand and the cliff disappears (verified on a 70B / K=30 / 14k repro: 2:20 → 0:55 wall, no other change). See #72 for the diagnostic walk.
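The variable has to be in the environment before PyTorch's CUDA caching allocator initializes; exporting it in the shell (`PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python your_script.py`) is the simplest route. From Python, a sketch that adds the option while preserving any other comma-separated allocator settings already present (the helper name is hypothetical):

```python
import os
from typing import MutableMapping


def enable_expandable_segments(env: MutableMapping[str, str] = os.environ) -> str:
    """Set expandable_segments:True in PYTORCH_CUDA_ALLOC_CONF, keeping other options."""
    current = env.get("PYTORCH_CUDA_ALLOC_CONF", "")
    parts = [opt.strip() for opt in current.split(",")]
    # Drop empty entries and any existing expandable_segments setting.
    opts = [opt for opt in parts if opt and not opt.startswith("expandable_segments:")]
    opts.append("expandable_segments:True")
    env["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(opts)
    return env["PYTORCH_CUDA_ALLOC_CONF"]


# Call this before `import torch` so the allocator sees it at initialization.
enable_expandable_segments()
```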

Reference callbacks

Four callbacks ship with the library as examples and integration tests:

  • RawActivations — persist per-sample activations, pooled (last_token_only=True) by default to avoid an [N, S, H] memory landmine.
  • IncrementalPCA — fit a PCA basis over the entire dataset in a single pass.
  • DiffOfMeans — compute per-class activation means for binary-labeled data.
  • SteerInBasis — additive intervention in a pre-computed basis; phase = "write".
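IncrementalPCA and DiffOfMeans both reduce to statistics that can be accumulated batch by batch, which is what lets them finish in a single pass. The NumPy sketch below uses a different technique from scikit-learn's IncrementalPCA (which updates an SVD incrementally): it accumulates the sum and the scatter matrix Σxxᵀ exactly, then eigendecomposes the covariance once at the end. It is not fpwap's implementation; `StreamingStats` is hypothetical and assumes pooled [batch, H] activations with 0/1 labels.

```python
import numpy as np


class StreamingStats:
    """One-pass accumulator for a PCA basis and a binary diff-of-means direction."""

    def __init__(self, hidden: int) -> None:
        self.n = 0
        self.total = np.zeros(hidden)                 # running sum of x
        self.scatter = np.zeros((hidden, hidden))     # running sum of x x^T
        self.class_total = np.zeros((2, hidden))      # per-class sums
        self.class_n = np.zeros(2, dtype=np.int64)    # per-class counts

    def update(self, acts: np.ndarray, labels: np.ndarray) -> None:
        """acts: [batch, hidden] pooled activations; labels: [batch] of 0/1."""
        acts = np.asarray(acts, dtype=np.float64)
        labels = np.asarray(labels)
        self.n += acts.shape[0]
        self.total += acts.sum(axis=0)
        self.scatter += acts.T @ acts
        for c in (0, 1):
            mask = labels == c
            self.class_total[c] += acts[mask].sum(axis=0)
            self.class_n[c] += int(mask.sum())

    def covariance(self) -> np.ndarray:
        """Sample covariance: (sum x x^T - n m m^T) / (n - 1)."""
        mean = self.total / self.n
        return (self.scatter - self.n * np.outer(mean, mean)) / (self.n - 1)

    def pca_basis(self, k: int) -> np.ndarray:
        """Top-k principal directions as rows, largest variance first."""
        _, vecs = np.linalg.eigh(self.covariance())  # eigenvalues in ascending order
        return vecs[:, ::-1][:, :k].T

    def diff_of_means(self) -> np.ndarray:
        """Mean of class 1 minus mean of class 0."""
        means = self.class_total / self.class_n[:, None]
        return means[1] - means[0]


rng = np.random.default_rng(0)
X = rng.normal(size=(400, 5)) * np.array([5.0, 3.0, 1.0, 0.5, 0.2])
y = (rng.random(400) < 0.5).astype(int)
stats = StreamingStats(hidden=5)
for start in range(0, 400, 64):  # feed the data in batches, as a sweep would
    stats.update(X[start:start + 64], y[start:start + 64])
```

The H × H scatter matrix is the memory cost of this approach: for H = 8192 in float64 it is about 0.54 GB per layer. Accumulating raw moments in float64 limits, but does not eliminate, cancellation error when the mean is large relative to the spread.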

Anything beyond these four is a consumer's problem.

Integrating fpwap into a research codebase

The recommended shape is a single classmethod on your codebase's activation-source type, inserted above any per-batch sharding your framework does:

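from fpwap import Sweep
from fpwap.callbacks.common import RawActivations

class Activations:
    @classmethod
    def from_fpwap(cls, model_id, prompts, layers, pool="last_token"):
        # One Sweep over the whole dataset, so each layer is loaded once.
        run = Sweep(
            model=model_id,
            dataset=_as_dataset(prompts),  # your adapter: prompts -> iterable of {"input_ids": ...}
            seq_len=...,                   # set for your data
            callbacks=[
                RawActivations(
                    layers=layers,
                    last_token_only=(pool == "last_token"),
                ),
            ],
        )
        return cls.from_result(run.run())  # your constructor from an fpwap result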

Branch use_fpwap at your dispatch layer — the same place you'd branch between from_model, from_goodfire, etc. — not inside a per-batch loop. fpwap's value (amortizing layer loads across the whole dataset) only materializes if it sees the dataset; if your framework shards externally and calls an extractor per shard, lift the dispatch up one level before integrating.
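The cost of dispatching at the wrong level is easy to quantify. If an extractor that loads each layer once is called per shard, every call streams the full weights again. A back-of-envelope helper (hypothetical, not part of fpwap):

```python
import math


def dataset_weight_io_gb(weights_gb: float, n_samples: int, shard_size: int) -> float:
    """Weight bytes streamed when the dataset is split into shards of `shard_size`
    and a load-each-layer-once extractor is called separately on each shard."""
    n_calls = math.ceil(n_samples / shard_size)
    return weights_gb * n_calls


whole = dataset_weight_io_gb(140, 10_000, 10_000)   # one Sweep over everything
sharded = dataset_weight_io_gb(140, 10_000, 100)    # one Sweep per 100-prompt shard
```

For a 70B model's 140 GB of weights and 10,000 prompts, one Sweep over the whole dataset streams 140 GB; shards of 100 prompts turn that into 100 Sweeps and 14 TB, the same order as the naive per-batch approach.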

Scope

fpwap is a plumbing layer. It produces activations and accepts transforms. It does not know what a probe is. Linear probe fitting, SAE training, attribution analysis, and any other statistical modeling of activations belong in consumer libraries. If it requires knowing what a probe is, it's out of scope.

Related work

The loop inversion at the heart of fpwap — load each layer once, stream the dataset through it — was explored independently by FlexGen (Sheng et al., ICML 2023) for high-throughput generative inference on a single GPU. FlexGen calls this a "zig-zag block schedule" and proves it is within 2× of I/O-optimal (Theorem 4.1) — a result that applies directly to fpwap's loop, since our schedule is the same modulo KV cache. FlexGen solves a harder scheduling problem (KV cache placement across GPU/CPU/disk, multi-step autoregressive decoding, CPU compute delegation) and applies 4-bit group-wise quantization to further compress weights. fpwap targets a narrower regime — forward-pass activation extraction for mechanistic interpretability — where full precision is non-negotiable and generation is not needed, so the implementation is much simpler. The absence of KV cache and autoregressive decoding also means fpwap's cost model has fewer free variables, making strategy selection tractable without an LP solver.

Status

Llama-3.1-405B on a single RTX 5090 (32 GB VRAM), streaming 803 GB of bf16 weights from NVMe — 45.7 tok/s in under 12 minutes. 70B at 10,000 prompts × 128 tokens hits 1,221 tok/s. That's the regime fpwap exists for: the model doesn't fit in VRAM (or even RAM), the dataset is thousands of prompts, and no quantization is involved. Measured on the reference machine (RTX 5090, 128 GB DDR5, PCIe 5.0 NVMe):

| Model | Path | Samples × seq_len | Throughput (bf16) | SPEC target |
|---|---|---|---|---|
| Llama-3.1-405B-Instruct | streaming, prefetch | 256 × 128 | 45.7 tok/s | ≥ 180 tok/s |
| Llama-3.3-70B-Instruct | streaming, prefetch | 10,000 × 128 | 1,221 tok/s | ≥ 950 tok/s |
| Llama-3.3-70B-Instruct | streaming | 1,024 × 128 | 1,026 tok/s | ≥ 950 tok/s |
| Llama-3.1-8B-Instruct | streaming | 1,024 × 128 | 10,442 tok/s | ≥ 5,000 tok/s |
| Llama-3.1-8B-Instruct | preloaded | 256 × 128 | 11,894 tok/s | ≥ 5,000 tok/s |

The 405B number is end-to-end across 126 layers streaming 803 GB of weights from NVMe SSD at 1.12 GB/s sustained, with prefetch fully hiding disk reads behind compute (0.000s load per layer at steady state). The 70B hero number is end-to-end across 80 layers with a pinned-CPU residual buffer (21 GB), async D2H, and a worker-thread weight prefetch that overlaps layer L+1's safetensors read with layer L's compute.

Baseline sanity: an 8B streaming-vs-naive head-to-head shows a 7.25× speedup at 1024 × 128 (SPEC §17 ratio target ≥ 4×). The naive baseline is accelerate.cpu_offload at 1,440 tok/s, reproducible via scripts/benchmark.py --mode naive. 70B can't ratio-test on this machine (141 GB bf16 > 128 GB RAM for cpu_offload); the 70B claim is absolute throughput.
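The Status numbers cross-check against each other. Plain arithmetic on the figures quoted above (no fpwap calls):

```python
# 405B: wall-clock from tokens/throughput vs. from bytes/sustained bandwidth.
tokens_405b = 256 * 128                    # 32,768 tokens
wall_405b_from_tput = tokens_405b / 45.7   # seconds
wall_405b_from_disk = 803 / 1.12           # 803 GB at 1.12 GB/s, seconds

# 70B hero run: 10,000 × 128 tokens at 1,221 tok/s.
tokens_70b = 10_000 * 128
wall_70b_s = tokens_70b / 1_221

# 8B streaming vs. naive accelerate.cpu_offload.
speedup_8b = 10_442 / 1_440
```

For 405B, tokens over throughput and bytes over sustained bandwidth both give about 717 s (just under 12 minutes), consistent with the run being paced by weight streaming from NVMe. The 70B run of 1.28M tokens takes about 17.5 minutes, and 10,442 / 1,440 reproduces the 7.25× 8B ratio.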

Correctness: tests/gpu/test_real_llama_bit_exact.py runs Llama-3.2-1B in bf16 on CUDA and compares every layer's residual_post against a naive HF forward: bit-exact (torch.equal) at every real (non-padding) token position. When microbatch_size equals the naive batch size, the bf16 outputs match exactly; at different microbatch sizes, they diverge by LSB-level accumulation noise (see the memory note on bf16_microbatch_determinism).

What's wired:

  • Pre-loaded and streaming model paths.
  • Sweep + Callback + Result API.
  • Padded-batch + attention-mask propagation.
  • RoPE-aware Llama plumbing; GPT-2 plumbing.
  • All four hooks (residual_pre, attn_out, mlp_out, residual_post), with a fast-path block forward when no sub-layer hook is wanted, and WriteBack at every hook. Sub-layer WriteBack is threaded through the block mid-forward, so the modified tensor actually affects downstream compute.
  • All four reference callbacks shipped (RawActivations, IncrementalPCA, DiffOfMeans, SteerInBasis).
  • result.activations(...).
  • tqdm progress, plus callable progress=reporter emitting ProgressEvents for wandb/rich sinks.
  • Pinned-CPU buffer_device="cpu" with async D2H copy, so oversized residual buffers don't block compute.
  • Worker-thread concurrent weight prefetch on the streaming path (layer L+1's safetensors read + H2D overlap with layer L's compute).
  • MemmapBackend for disk-backed emits.
  • ProfileReport.throughput_tok_per_s() / weight_bandwidth_gb_per_s().
  • verify=True fail-fast against a naive-forward baseline (pre-loaded models only).
  • Per-layer on_layer_end artifacts collected into result.artifacts.

Model families covered by the structural matcher: Llama, Mistral, Qwen2, Gemma, DeepSeek-V2, and any future HF causal LM exposing the same model.{embed_tokens, layers, rotary_emb} layout. GPT-2 covered by its own plumbing.

What's not yet: checkpoint/resume, NVMe-backed ResidualBuffer, verify=True on the streaming path (pre-loaded only).

See SPEC.md for the full design.



Download files


Source Distribution

fpwap-0.1.0.tar.gz (214.2 kB)

Uploaded Source

Built Distribution


fpwap-0.1.0-py3-none-any.whl (56.6 kB)

Uploaded Python 3

File details

Details for the file fpwap-0.1.0.tar.gz.

File metadata

  • Download URL: fpwap-0.1.0.tar.gz
  • Size: 214.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for fpwap-0.1.0.tar.gz
Algorithm Hash digest
SHA256 9075871cb03aa01f8423aaea4c4bbc5d13e79448e85e263dcde33f7871c2543e
MD5 070f00dd0e60aae1ce6879b67a12fb3b
BLAKE2b-256 64cb95c6607e5a19cb63180a8f26029660e76596c3681d1e875403c6dd6f57b6


Provenance

The following attestation bundles were made for fpwap-0.1.0.tar.gz:

Publisher: release.yml on AlliedToasters/fpwap


File details

Details for the file fpwap-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: fpwap-0.1.0-py3-none-any.whl
  • Size: 56.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for fpwap-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 a660b9ab204a03017deb3ccdc31239d0239189b31a369627c1a6ca97c8bdb746
MD5 6899194de452343af56cd8dbef589115
BLAKE2b-256 22d0a77f43fb02acd67afbb4da7a492044274a9aa46b491d1cabc2841e216bba


Provenance

The following attestation bundles were made for fpwap-0.1.0-py3-none-any.whl:

Publisher: release.yml on AlliedToasters/fpwap

