JAX-native thermal sampling for discrete energy-based models

Hamon


Hamon is a JAX library for sampling from discrete probabilistic graphical models. It provides GPU-accelerated block Gibbs sampling, non-reversible parallel tempering with adaptive schedule optimization, and tools for building and training Ising models, RBMs, and other discrete energy-based models.

Built on Extropic AI's thrml foundation, Hamon diverges as an independent library with original algorithmic contributions and performance optimizations.

Why "Hamon"?

In Japanese swordsmithing, the hamon (刃文, "blade pattern") is the visible wave that appears along the edge of a katana after differential hardening. The smith coats the blade in clay — thin along the cutting edge, thick along the spine — then heats the steel to critical temperature and quenches it in water. The edge cools fast into hard martensite; the spine cools slowly into tough pearlite. The boundary between these two phases is the hamon: a pattern born entirely from a thermal process, where controlled temperature gradients reveal structure hidden in disordered steel.

The parallel to this library is direct. Hamon explores discrete energy landscapes by running chains at different temperatures and exchanging information across the thermal gradient. Structure emerges at the boundary between mixing regimes — hot chains explore freely, cold chains resolve fine detail, and the communication between them is what makes sampling work. The hamon on a blade is proof that a thermal process found the right boundary. The diagnostics in this library measure the same thing.

Installation

pip install hamon

For development:

git clone https://github.com/dek3rr/hamon.git
cd hamon
pip install -e ".[development,testing,examples]"

Requires Python ≥ 3.12 and a JAX installation (see the GPU setup guide for accelerator support).

Device routing

With a CUDA-enabled JAX build installed, JAX places everything on the GPU, including the small, dispatch-bound programs that a CPU finishes several times faster. Hamon's entry points (nrpt, tune_schedule, tune_chains, ising_sample, sample_states, sample_with_observation, …) therefore take a device argument:

  • "auto" (default) — with no accelerator visible, placement is untouched. Otherwise the work score n_chains × free nodes decides: small workloads run on the CPU, large ones on the accelerator. The default threshold (4096, the steady-state crossover measured on an RTX 5080) can be overridden with HAMON_DEVICE_THRESHOLD (calibrate yours with python benchmarks/device_crossover.py); HAMON_DEVICE=cpu|gpu|none forces a choice without code changes. Very short one-shot flows are compile-dominated and can favor the CPU regardless of size — pass device="cpu" for those, or set JAX_COMPILATION_CACHE_DIR so repeated runs skip GPU compilation entirely.
  • "cpu" / "gpu" — that platform, raising if it is not visible.
  • a concrete jax.Device — used as-is.
  • None — hamon never touches placement.

Routing re-commits the entry arrays (program tensors, states, β ladder) to the chosen device and returns outputs committed there; pass device=None to keep full manual control of placement. Orchestrators resolve the device once and reuse it across all tuning phases, so jit caches stay warm.
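
The "auto" rule described above can be sketched in plain Python. This is an illustration only, not Hamon's code: the function name resolve_platform and its arguments are hypothetical, the tie-break at exactly the threshold is an assumption, and the sketch does not check visibility when HAMON_DEVICE forces a platform.

```python
import os

DEFAULT_THRESHOLD = 4096  # default crossover quoted in the text above


def resolve_platform(n_chains, n_free_nodes, accelerator_visible, env=None):
    """Return "cpu", "gpu", or None (None = leave placement untouched).

    Hypothetical helper: mirrors the documented "auto" behaviour only.
    """
    env = os.environ if env is None else env

    # HAMON_DEVICE=cpu|gpu|none forces a choice without code changes.
    forced = env.get("HAMON_DEVICE")
    if forced in ("cpu", "gpu"):
        return forced
    if forced == "none" or not accelerator_visible:
        return None

    # Work score: number of chains times number of free nodes.
    threshold = int(env.get("HAMON_DEVICE_THRESHOLD", DEFAULT_THRESHOLD))
    work_score = n_chains * n_free_nodes

    # Assumption: a score exactly at the threshold goes to the accelerator.
    return "gpu" if work_score >= threshold else "cpu"


# 8 chains x 100 free nodes = 800, below the default threshold.
small = resolve_platform(8, 100, accelerator_visible=True, env={})
# 64 chains x 100 free nodes = 6400, above it.
large = resolve_platform(64, 100, accelerator_visible=True, env={})
```

Passing env as a plain dict keeps the sketch testable without touching the real process environment.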

Quick example

import jax
import jax.numpy as jnp
from hamon import SpinNode, Block, SamplingSchedule, sample_states
from hamon.models import IsingEBM, IsingSamplingProgram, hinton_init

nodes = [SpinNode() for _ in range(5)]
edges = [(nodes[i], nodes[i + 1]) for i in range(4)]
model = IsingEBM(nodes, edges, jnp.zeros(5), jnp.ones(4) * 0.5, jnp.array(1.0))

free_blocks = [Block(nodes[::2]), Block(nodes[1::2])]
program = IsingSamplingProgram(model, free_blocks, clamped_blocks=[])

key = jax.random.key(0)
k_init, k_samp = jax.random.split(key, 2)
init_state = hinton_init(k_init, model, free_blocks, ())
schedule = SamplingSchedule(n_warmup=100, n_samples=1000, steps_per_sample=2)

samples = sample_states(k_samp, program, schedule, init_state, [], [Block(nodes)])
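
For intuition, here is a pure-Python sketch of what two-block Gibbs sampling does on the same 5-spin chain. It does not use Hamon. It assumes the energy convention E(s) = -β Σ J s_i s_j with zero biases, and it assumes n_warmup, n_samples, and steps_per_sample mean "discarded sweeps", "kept samples", and "sweeps between kept samples". The function name block_gibbs_chain is hypothetical.

```python
import math
import random


def block_gibbs_chain(n_spins=5, coupling=0.5, beta=1.0,
                      n_warmup=100, n_samples=1000, steps_per_sample=2, seed=0):
    """Two-colour block Gibbs sampler for an open Ising chain (toy sketch)."""
    rng = random.Random(seed)
    spins = [rng.choice((-1, 1)) for _ in range(n_spins)]
    # Even and odd sites: no two sites in the same block are neighbours,
    # so each block could be updated in parallel.
    blocks = [range(0, n_spins, 2), range(1, n_spins, 2)]

    def sweep():
        for block in blocks:
            for i in block:
                field = 0.0
                if i > 0:
                    field += coupling * spins[i - 1]
                if i < n_spins - 1:
                    field += coupling * spins[i + 1]
                # P(s_i = +1 | neighbours) = sigmoid(2 * beta * field)
                p_up = 1.0 / (1.0 + math.exp(-2.0 * beta * field))
                spins[i] = 1 if rng.random() < p_up else -1

    for _ in range(n_warmup):
        sweep()

    samples = []
    for _ in range(n_samples):
        for _ in range(steps_per_sample):
            sweep()
        samples.append(tuple(spins))
    return samples


toy_samples = block_gibbs_chain()
# Nearest-neighbour correlation; for this open chain the exact value is tanh(beta * coupling).
corr_01 = sum(s[0] * s[1] for s in toy_samples) / len(toy_samples)
```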

Non-reversible parallel tempering

Hamon implements adaptive NRPT based on Syed et al. (2021), with vectorized swaps that exploit the temperature-linearity of Ising energies. The primary interface is autotuning: autotune and autosample discover the chain count, the local-exploration count, and the schedule for you:

import jax
from hamon import autosample, autotune

# ebm, program, and init_factory are assumed to be defined already.

# Tunes N, gibbs_steps_per_round, and the β ladder, then draws from the target.
samples, report = autosample(
    jax.random.key(0),
    n_samples=2000,
    ebm=ebm,                  # a single template EBM (any β)
    program=program,
    init_factory=init_factory,  # (n_chains, ebms, programs) -> [init per chain]
    clamp_state=[],
    beta_range=(0.0, 1.0),
)
print(report.summary())       # N, n_expl, Λ, round-trip efficiency

# Or keep the tuned plan and draw repeatedly without re-tuning:
plan = autotune(jax.random.key(1), ebm=ebm, program=program,
                init_factory=init_factory, clamp_state=[])
more = plan.sample(jax.random.key(2), 5000)

For Ising models, ising_sample wraps this in a one-liner (biases, edges, weights → samples) and tunes everything for you.

Key features of the NRPT implementation:

  • Full autotuning: autotune runs chain count → exploration count → schedule in dependency order, reusing the schedule across exploration probes and never re-discovering N; returns an NRPTPlan for cheap repeated draws
  • Device-calibrated exploration: tune_exploration picks gibbs_steps_per_round by maximizing ESS per measured wall-second, so it self-calibrates (n_expl=1 on a compute-bound CPU, n_expl>1 on a dispatch-bound GPU where extra sweeps are nearly free — measured 1.7–2.3× ESS/sec)
  • Vectorized swaps: 1 energy evaluation per chain (not 4 per pair), all non-overlapping swaps execute simultaneously via permutation indexing
  • Temperature-linear mode: one β = 1 base program serves every chain; interactions are scaled by each chain's β inside the kernel, so no per-chain program construction and n_chains× less interaction memory
  • Chain count discovery: tune_chains probes for the right N from Λ
  • Adaptive scheduling: tune_schedule equalizes rejection rates, minimizing the global communication barrier Λ
  • Round trip tracking: estimates Λ and predicted optimal rate τ̄ = 1/(2+2Λ)
  • Effective sample size: effective_sample_size reports per-variable ESS (the honest denominator on Monte-Carlo error); folded into report_nrpt_diagnostics
  • Log normalizing constant: opt-in NRPTEnergyObserver + thermodynamic_integration recover log Z / model evidence / free energy from the tempering energies — the quantity ordinary MCMC discards
  • Compile cache by default: autotune enables JAX's persistent compile cache to amortize the multi-probe recompiles across runs
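
As a rough illustration of the swap and round-trip items in this list, the sketch below runs one deterministic even-odd (DEO) swap round using a single base energy per chain, then estimates Λ as the sum of mean rejection rates and evaluates τ̄ = 1/(2+2Λ). It is plain Python rather than Hamon's vectorized kernel. The function names are hypothetical, and it assumes each chain targets a density proportional to exp(-β E), so the acceptance probability for neighbours i and i+1 is min(1, exp((β_{i+1} - β_i)(E_{i+1} - E_i))).

```python
import math
import random


def deo_swap_round(betas, energies, round_index, rng):
    """One DEO swap round over an ascending beta ladder.

    energies[i] is the base (beta = 1) energy of the state at rung i,
    evaluated once per chain. Returns (perm, rejection): perm[i] is the rung
    whose state moves to rung i; rejection[i] is the rejection probability of
    pair (i, i + 1), or None if that pair was not proposed this round.
    """
    n = len(betas)
    perm = list(range(n))
    rejection = [None] * (n - 1)
    # Even rounds propose (0,1), (2,3), ...; odd rounds propose (1,2), (3,4), ...
    # The proposed pairs are disjoint, so they could all run at once.
    for i in range(round_index % 2, n - 1, 2):
        log_ratio = (betas[i + 1] - betas[i]) * (energies[i + 1] - energies[i])
        accept = math.exp(min(0.0, log_ratio))
        rejection[i] = 1.0 - accept
        if rng.random() < accept:
            perm[i], perm[i + 1] = i + 1, i
    return perm, rejection


def barrier_and_round_trip_rate(mean_rejection):
    """Estimate Lambda as the sum of per-pair mean rejection rates,
    then return (Lambda, 1 / (2 + 2 * Lambda))."""
    lam = sum(mean_rejection)
    return lam, 1.0 / (2.0 + 2.0 * lam)


rng = random.Random(0)
perm, rejection = deo_swap_round([0.0, 0.5, 1.0], [-2.0, 1.0, 3.0], 0, rng)
lam, tau_bar = barrier_and_round_trip_rate([0.2, 0.3, 0.5])
```

In an array library the same round becomes a gather with perm as the index, which is the "permutation indexing" the list refers to.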

Log Z and effective sample size

import jax
import jax.numpy as jnp
from hamon import NRPTEnergyObserver, nrpt_log_normalizing_constant
from hamon.nrpt import tune_schedule

obs = NRPTEnergyObserver(n_chains=8)
states, stats = tune_schedule(
    jax.random.key(0),
    init_states=[init_state] * 8,
    clamp_state=[],
    n_rounds=500,
    gibbs_steps_per_round=5,
    initial_betas=jnp.linspace(0.0, 1.0, 8),
    ebm=ebm,
    program=program,
    observer=obs,  # opt-in: accumulates mean energy on the production run
)

# log Z(1) for an n-spin model (β=0 reference is uniform over 2**n states).
log_z = nrpt_log_normalizing_constant(stats, log_z0=len(nodes) * jnp.log(2.0))

# Effective sample size of the cold-chain trace.
from hamon import effective_sample_size, report_nrpt_diagnostics

report = report_nrpt_diagnostics(stats, samples=my_cold_chain_samples)
print(report.summary())  # includes ess(min)/ess(median)/ess_fraction
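
The idea behind recovering log Z from tempering energies is thermodynamic integration: for targets proportional to exp(-β E), d log Z / dβ = -⟨E⟩_β, so log Z(1) = log Z(0) - ∫ ⟨E⟩_β dβ over [0, 1]. The sketch below checks this on a 5-spin open chain with a trapezoid rule. It is not Hamon's implementation: the function names are hypothetical, and exact enumeration stands in for the per-rung mean energies that a sampler would estimate.

```python
import itertools
import math

COUPLING = 0.5


def energy(spins):
    """Base (beta = 1) energy of an open Ising chain with zero biases."""
    return -COUPLING * sum(a * b for a, b in zip(spins, spins[1:]))


def exact_stats(n_spins, beta):
    """Return (log Z(beta), <E>_beta) by brute-force enumeration."""
    energies = [energy(s) for s in itertools.product((-1, 1), repeat=n_spins)]
    weights = [math.exp(-beta * e) for e in energies]
    z = sum(weights)
    mean_e = sum(w * e for w, e in zip(weights, energies)) / z
    return math.log(z), mean_e


def thermodynamic_integration(betas, mean_energies, log_z0):
    """Trapezoid rule for log Z(beta_max) = log_z0 - integral of <E>_beta."""
    log_z = log_z0
    for k in range(len(betas) - 1):
        width = betas[k + 1] - betas[k]
        log_z -= 0.5 * width * (mean_energies[k] + mean_energies[k + 1])
    return log_z


n_spins = 5
betas = [k / 32 for k in range(33)]  # uniform ladder from 0 to 1
mean_energies = [exact_stats(n_spins, b)[1] for b in betas]

# At beta = 0 the distribution is uniform over 2**n states, so log Z(0) = n log 2.
log_z_ti = thermodynamic_integration(betas, mean_energies, n_spins * math.log(2.0))
log_z_exact = exact_stats(n_spins, 1.0)[0]
```

With sampled rather than exact mean energies, the accuracy depends on both the ladder spacing and the Monte Carlo error at each rung.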

What makes Hamon fast

All chains run in one kernel. Parallel tempering uses jax.vmap over chains instead of a Python loop. Compile time is constant regardless of chain count.

No redundant work in the sampler loop. Global state is threaded through lax.scan as a carry. Block updates write back via contiguous slice updates with static offsets (scatters only as a fallback for non-contiguous layouts) instead of rebuilding the full state tensor each iteration.
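
A minimal JAX sketch of the two patterns above: lax.scan carries the state through the sweeps, each block writes back through a slice with static offsets, and jax.vmap runs every chain in one compiled call. The toy bipartite model, the constants, and the function names are assumptions made for illustration and do not come from Hamon.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_VIS, N_HID, N_STEPS = 3, 3, 50
W = 0.3 * jnp.ones((N_VIS, N_HID))  # toy couplings between the two blocks


def sample_block(key, field, beta):
    """Draw +/-1 spins with P(+1) = sigmoid(2 * beta * field)."""
    p_up = jax.nn.sigmoid(2.0 * beta * field)
    return jnp.where(jax.random.uniform(key, field.shape) < p_up, 1.0, -1.0)


def sweep(state, key, beta):
    """One block Gibbs sweep; each block is a contiguous slice of the state."""
    k_vis, k_hid = jax.random.split(key)
    # Static offsets: slice bounds are Python ints known at trace time.
    state = state.at[:N_VIS].set(sample_block(k_vis, W @ state[N_VIS:], beta))
    state = state.at[N_VIS:].set(sample_block(k_hid, W.T @ state[:N_VIS], beta))
    return state


def run_chain(key, beta, init_state):
    """Thread the state through N_STEPS sweeps as a scan carry."""
    def body(state, step_key):
        state = sweep(state, step_key, beta)
        return state, state  # (new carry, recorded sample)
    return lax.scan(body, init_state, jax.random.split(key, N_STEPS))


# One compiled program for all chains, mapped over keys, betas, and states.
run_all_chains = jax.jit(jax.vmap(run_chain))

betas = jnp.linspace(0.0, 1.0, 4)
keys = jax.random.split(jax.random.key(0), betas.shape[0])
init = jnp.ones((betas.shape[0], N_VIS + N_HID))
final_states, traces = run_all_chains(keys, betas, init)
```

Adding chains only grows the leading batch axis of the inputs; the traced program itself is unchanged.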

Energy evaluation skips unnecessary work. Pre-built BlockSpec objects are passed through directly — no reconstruction on every energy() call. Padded interaction entries are pre-zeroed at program construction, so samplers skip the per-step active-mask multiply.

Accumulator dtypes are explicit. The moment accumulator pins its dtype at construction, and conditional samplers accumulate in the weights' dtype. This avoids silent dtype promotion on GPU and prevents float64 sums from being seeded with float32 zeros.

Citing Hamon

If you use Hamon in your research, please cite:

@software{kerr2026hamon,
    author       = {Kerr, Douglas E. Jr.},
    title        = {Hamon: JAX-Native Thermal Sampling for Discrete Energy-Based Models},
    year         = {2026},
    url          = {https://github.com/dek3rr/hamon},
    version      = {0.7.0},
    license      = {Apache-2.0},
}

Hamon's block sampling and PGM infrastructure is derived from thrml (v0.1.3) by Extropic AI, licensed under Apache 2.0. See NOTICE for full attribution. If you use the underlying block Gibbs framework, please also cite:

@misc{jelincic2025efficient,
    title        = {An efficient probabilistic hardware architecture for diffusion-like models},
    author       = {Andraž Jelinčič and Owen Lockwood and Akhil Garlapati and Guillaume Verdon and Trevor McCourt},
    year         = {2025},
    eprint       = {2510.23972},
    archivePrefix= {arXiv},
    primaryClass = {cs.LG},
}

The non-reversible parallel tempering implementation is based on:

Syed, S., Bouchard-Côté, A., Deligiannidis, G., & Doucet, A. (2021). Non-Reversible Parallel Tempering: a Scalable Highly Parallel MCMC Scheme. arXiv:1905.02939

License

Apache 2.0. See LICENSE.
