Beam-Infer

A Python package for beam simulation and parameter inference using raster scanning models

Beam-Infer is a Python library for beam parameter inference from raster-scanned images. It provides three high-level wrappers for parameter estimation using optimization and MCMC methods, with full control over configuration while maintaining a simple, stable API.

Note: This package is based on the original work by Joel Henriksson (Lund University, LTH) from his master's thesis "Inverse Problems in Proton Beam Imaging at ESS: Analysis and Numerical Methods". The core calculation base and algorithms were implemented by Joel Henriksson. This package integrates that work into ESS systems.

Original Repository: ess-beam-imaging-inverse-problems

Numerical equivalence to the legacy base implementation (forward model, loss, optimizer pipeline, Poisson MCMC) is tested and documented; see docs/EQUIVALENCE.md. This package focuses on the calculation layer only; EPICS integration and I/O are handled elsewhere.

Features

  • Complete Pipeline: estimate_with_optimizer handles preprocessing, grid construction, initialization, and optimization
  • Multistart Optimization: run_multistart_minimization for robust parameter recovery
  • MCMC Sampling: run_mcmc_inference for Bayesian parameter estimation with uncertainty quantification
  • Flexible Configuration: Rich configuration via dataclasses for all inference methods
  • JAX-based: Fast, GPU-accelerated computations using JAX
  • Reproducible: Optional random seeds for deterministic multistart runs

Installation

From Source

git clone https://gitlab.esss.lu.se/hugovalim/beam-infer.git
cd beam-infer
pip install -e .

Development Installation

pip install -e ".[dev]"

Quick Start

Full Pipeline (Recommended)

The estimate_with_optimizer function provides a complete pipeline from raw camera image to estimated parameters:

import numpy as np
import beam_infer

# Raw camera image (2D numpy array)
raw_image = np.array(...)  # Your image data

# Metadata (known quantities)
meta = beam_infer.Metadata(
    pulse_duration_ms=0.05,  # Pulse duration in milliseconds
    fx=39.55,                 # Horizontal scan frequency (kHz); optional, held fixed if provided
    fy=29.05,                 # Vertical scan frequency (kHz); optional, held fixed if provided
)

# Configuration
config = beam_infer.OptimizerConfig(
    max_iterations=100,       # Maximum L-BFGS-B iterations per run
    multistart={"n_starts": 5, "type": "uniform"},  # Multistart configuration
    regularization={"params": {"sigx": 0.1, "sigy": 0.1}},  # Optional regularization
    random_seed=42,           # Optional: for reproducible multistart runs
    verbose=False,            # Print iteration details
)

# Run estimation
result = beam_infer.estimate_with_optimizer(raw_image, meta=meta, config=config)

# Access results
print(f"Estimated parameters: {result.estimated_parameters}")
print(f"Final loss: {result.objective_value:.6e}")
print(f"Success: {result.success}")
print(f"Crop bounds: {result.crop_bounds}")

Multistart Minimization

For more control over the optimization process:

import numpy as np
import jax.numpy as jnp
import beam_infer

# Preprocessed image and grids (you handle preprocessing)
observed_image = np.array(...)  # Already preprocessed
X, Y = ...  # 2D coordinate grids matching image shape
t_vals = ...  # 1D time array

# Initial guess (10 parameters)
k0 = jnp.array([60.0, 20.0, 13.5, 5.05, 0.0, 0.0, 39.55, 29.05, 0.0, 0.0])

# Bounds (optional)
lower = jnp.array([0.0, 0.0, 2.0, 2.0, -20.0, -20.0, 20.0, 20.0, -np.pi, -np.pi])
upper = jnp.array([100.0, 50.0, 20.0, 20.0, 20.0, 20.0, 50.0, 50.0, np.pi, np.pi])

# Configuration
config = beam_infer.MinimizationConfig(
    initial_guess=k0,
    bounds=(lower, upper),
    max_iterations=100,
    verbose=False,
)

# Run multistart optimization
result = beam_infer.run_multistart_minimization(
    observed_image,
    X, Y, t_vals,
    config,
    num_starts=10,
    random_seed=42,  # For reproducibility
)

# Access best result
best = result.best_result
print(f"Best parameters: {best.estimated_parameters}")
print(f"Best loss: {best.objective_value:.6e}")

# Access all results (sorted by loss)
for i, r in enumerate(result.results):
    print(f"Run {i}: loss = {r.objective_value:.6e}")

MCMC Inference

For Bayesian parameter estimation with uncertainty quantification:

import numpy as np
import beam_infer

# Observed image (counts or normalized intensity)
observed_image = np.array(...)  # 2D image
X, Y = ...  # 2D coordinate grids
t_vals = ...  # 1D time array

# Bounds
lower = np.array([0.0, 0.0, 2.0, 2.0, -20.0, -20.0, 20.0, 20.0, -np.pi, -np.pi])
upper = np.array([100.0, 50.0, 20.0, 20.0, 20.0, 20.0, 50.0, 50.0, np.pi, np.pi])

# MCMC configuration
config = beam_infer.MCMCConfig(
    lower_bounds=lower,
    upper_bounds=upper,
    num_samples=4000,      # Number of samples to collect
    burn_in=1000,          # Burn-in period
    thin=2,                # Thinning (keep every Nth sample)
    prior_type="uniform",  # or "gaussian"
    seed=42,               # Random seed
)

# Run MCMC
result = beam_infer.run_mcmc_inference(observed_image, X, Y, t_vals, config)

# Access results
print(f"Posterior mean: {result.posterior_mean}")
print(f"Posterior std: {result.posterior_std}")
print(f"Acceptance rate: {result.acceptance_rate:.2%}")
print(f"Samples shape: {result.samples.shape}")  # (num_samples, 10)
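
The posterior summaries reported above are column-wise statistics over the samples array. A sketch using a synthetic stand-in for result.samples (the real array comes from run_mcmc_inference):

```python
import numpy as np

# Stand-in for result.samples: shape (num_samples, 10)
rng = np.random.default_rng(0)
samples = rng.normal(loc=5.0, scale=0.2, size=(4000, 10))

# Column-wise posterior mean and standard deviation
posterior_mean = samples.mean(axis=0)
posterior_std = samples.std(axis=0)

# 95% credible interval per parameter
lo, hi = np.percentile(samples, [2.5, 97.5], axis=0)
```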

Supported Public API

The library exposes three main entrypoints plus parallel tempering helpers:

  1. estimate_with_optimizer — Complete pipeline (preprocessing + optimization)
  2. run_multistart_minimization — Multistart L-BFGS-B optimization
  3. run_mcmc_inference — Metropolis-Hastings MCMC sampling
  4. Parallel tempering: build_beta_ladder, sample_from_prior, run_parallel_tempering for replica-exchange MCMC (see API Reference)

Each entrypoint has corresponding configuration and result dataclasses. See the API Reference for complete documentation.
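
For orientation, a common way to build a tempering ladder is geometric spacing in the inverse temperature beta, from 1 (the target posterior) down to a small minimum. This is a sketch only; the package's build_beta_ladder may use a different scheme:

```python
import numpy as np

def geometric_beta_ladder(n_replicas, beta_min=0.05):
    """Geometrically spaced inverse temperatures from 1.0 down to beta_min.

    Replicas at small beta explore a flattened posterior and exchange
    states with colder replicas, helping the beta=1 chain escape local modes.
    """
    return np.geomspace(1.0, beta_min, n_replicas)

betas = geometric_beta_ladder(6)  # monotonically decreasing, betas[0] == 1.0
```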

Parameters

The beam model uses 10 parameters (in order):

| Index | Name | Description                     | Units   |
|-------|------|---------------------------------|---------|
| 0     | Ax   | Horizontal raster amplitude     | mm      |
| 1     | Ay   | Vertical raster amplitude       | mm      |
| 2     | sigx | Horizontal beam width (std dev) | mm      |
| 3     | sigy | Vertical beam width (std dev)   | mm      |
| 4     | cx   | Horizontal center offset        | mm      |
| 5     | cy   | Vertical center offset          | mm      |
| 6     | fx   | Horizontal scan frequency       | kHz     |
| 7     | fy   | Vertical scan frequency         | kHz     |
| 8     | phix | Horizontal phase offset         | radians |
| 9     | phiy | Vertical phase offset           | radians |
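
These parameters describe a Gaussian spot swept along a two-frequency raster path and integrated over the pulse. The following is a hypothetical sketch of such a forward model, not the package's implementation (which may use a different sweep waveform and normalization):

```python
import numpy as np

def raster_image(k, X, Y, t_vals):
    """Time-integrated Gaussian spot on a two-frequency raster path.

    k follows the parameter order in the table above. A sinusoidal sweep
    is used here for simplicity; units are mm, kHz, ms, and radians, so
    frequency * time is dimensionless.
    """
    Ax, Ay, sigx, sigy, cx, cy, fx, fy, phix, phiy = k
    image = np.zeros_like(X)
    for t in t_vals:
        mx = cx + Ax * np.sin(2 * np.pi * fx * t + phix)  # spot centre at time t
        my = cy + Ay * np.sin(2 * np.pi * fy * t + phiy)
        image += np.exp(-((X - mx) ** 2) / (2 * sigx**2)
                        - ((Y - my) ** 2) / (2 * sigy**2))
    return image / len(t_vals)  # average over the pulse

# Example: 64x64 grid (mm), 0.05 ms pulse sampled at 200 points
x = np.linspace(-30.0, 30.0, 64)
X, Y = np.meshgrid(x, x)
t = np.linspace(0.0, 0.05, 200)
k = np.array([20.0, 10.0, 3.0, 3.0, 0.0, 0.0, 39.55, 29.05, 0.0, 0.0])
img = raster_image(k, X, Y, t)
```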

Configuration Guide

OptimizerConfig

Controls the full pipeline (estimate_with_optimizer):

config = beam_infer.OptimizerConfig(
    max_iterations=100,              # Max L-BFGS-B iterations per run
    multistart={"n_starts": 5, "type": "uniform"},  # Multistart config
    regularization={"params": {"sigx": 0.1}},  # Regularization strengths
    random_seed=42,                  # Optional: for reproducible multistart
    verbose=False,                    # Print iteration details
    debug=False,                      # Return intermediate results
)

Multistart options:

  • "type": "uniform" - Uniform random sampling within bounds
  • "type": "gaussian" - Gaussian sampling around initial guess
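
As an illustration of the two strategies, here is how starting points could be drawn; this is a sketch, and the package's actual sampling may differ:

```python
import numpy as np

rng = np.random.default_rng(42)
lower = np.array([0.0, 0.0, 2.0, 2.0, -20.0, -20.0, 20.0, 20.0, -np.pi, -np.pi])
upper = np.array([100.0, 50.0, 20.0, 20.0, 20.0, 20.0, 50.0, 50.0, np.pi, np.pi])
k0 = 0.5 * (lower + upper)  # initial guess at the centre of the box

# "uniform": independent draws within the bounds
uniform_starts = rng.uniform(lower, upper, size=(5, 10))

# "gaussian": perturbations around the initial guess, clipped to the bounds
gaussian_starts = np.clip(
    rng.normal(k0, 0.1 * (upper - lower), size=(5, 10)), lower, upper
)
```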

Regularization:

  • Dictionary mapping parameter names to penalty strengths
  • Penalizes deviation from initial guess: lambda * (param - param0)^2
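
The quadratic penalty above can be sketched as follows; PARAM_INDEX and the helper are hypothetical illustrations, not package internals:

```python
import numpy as np

# Map regularizable parameter names to their indices (from the parameter table)
PARAM_INDEX = {"sigx": 2, "sigy": 3}

def regularization_penalty(k, k0, strengths):
    """sum of lambda_i * (k[i] - k0[i])**2 over the regularized parameters."""
    return sum(lam * (k[PARAM_INDEX[name]] - k0[PARAM_INDEX[name]]) ** 2
               for name, lam in strengths.items())

k0 = np.zeros(10)
k = k0.copy()
k[2] = 2.0  # sigx deviates from its initial guess by 2 mm
penalty = regularization_penalty(k, k0, {"sigx": 0.1})  # 0.1 * 2**2 = 0.4
```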

MinimizationConfig

Controls single/multistart optimization:

config = beam_infer.MinimizationConfig(
    initial_guess=k0,                # Initial parameter vector (10-D)
    bounds=(lower, upper),           # Optional: parameter bounds
    max_iterations=100,              # Max L-BFGS-B iterations
    verbose=False,                   # Print iteration details
    regularization={"sigx": 0.1},    # Optional: regularization strengths
)

MCMCConfig

Controls MCMC sampling:

config = beam_infer.MCMCConfig(
    lower_bounds=lower,              # Lower bounds (10-D array)
    upper_bounds=upper,              # Upper bounds (10-D array)
    num_samples=4000,                # Number of samples to collect
    burn_in=1000,                    # Burn-in period
    thin=2,                          # Thinning interval
    prior_type="uniform",            # "uniform" or "gaussian"
    prior_mean=None,                 # Required for Gaussian prior
    prior_std=None,                  # Required for Gaussian prior
    proposal_std=None,               # Optional: custom proposal std dev
    seed=42,                         # Random seed
    use_scan=True,                   # Use lax.scan (faster, default)
)
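
To clarify what num_samples, burn_in, thin, and seed control, here is a toy random-walk Metropolis-Hastings sampler on a standard-normal target. This is an illustration only, not the package's sampler:

```python
import numpy as np

def mh_sample(log_post, x0, num_samples, burn_in, thin, proposal_std, seed):
    """Random-walk Metropolis-Hastings with burn-in and thinning."""
    rng = np.random.default_rng(seed)
    x, lp = np.asarray(x0, dtype=float), log_post(x0)
    kept, accepted = [], 0
    total = burn_in + num_samples * thin
    for i in range(total):
        prop = x + rng.normal(0.0, proposal_std, size=x.shape)
        lp_prop = log_post(prop)
        if np.log(rng.uniform()) < lp_prop - lp:   # accept/reject step
            x, lp, accepted = prop, lp_prop, accepted + 1
        if i >= burn_in and (i - burn_in) % thin == 0:
            kept.append(x.copy())                  # keep every thin-th post-burn-in draw
    return np.array(kept), accepted / total

# Standard-normal target in 2-D
samples, acc = mh_sample(lambda v: -0.5 * v @ v, np.zeros(2),
                         num_samples=2000, burn_in=500, thin=2,
                         proposal_std=1.0, seed=42)
```

With thin=2 and burn_in=500, the sampler runs 500 + 2000 * 2 = 4500 steps but returns exactly 2000 draws.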

When to Use Which Method?

  • estimate_with_optimizer: Use when you have a raw camera image and want a complete pipeline. Handles preprocessing, grid construction, and optimization automatically.

  • run_multistart_minimization: Use when you need more control over preprocessing/grids, or want to run optimization from multiple starting points for robustness.

  • run_mcmc_inference: Use when you need uncertainty quantification, want to explore the posterior distribution, or need Bayesian inference.

Examples

See the examples/ directory for complete examples:

  • simple_estimation.py ⭐ - Simplest example using only public API (recommended for beginners)
  • basic_simulation.py - Basic beam image generation
  • parameter_variations.py - Effects of different parameters
  • raster_pattern.py - Visualize raster scanning patterns
  • inference_example.py - Parameter recovery from synthetic data
  • real_image_estimation.py - Real snapshot: optimizer + MCMC with comparison
  • comprehensive_example.py - Complete workflow demonstration

Run examples:

# Simplest example (public API only)
python -m examples.simple_estimation

# Real image: optimizer vs MCMC (requires beam snapshot NPZ in examples/)
python -m examples.real_image_estimation

# Comprehensive example (recommended)
python -m examples.comprehensive_example

See examples/README.md for detailed documentation on all examples.

Requirements

  • Python >= 3.10
  • numpy >= 1.20.0
  • jax >= 0.4.0
  • jaxlib >= 0.4.0
  • scipy >= 1.7.0

GPU Support (Optional)

For GPU acceleration, install JAX with CUDA support:

# For CUDA 12.x
pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Verify GPU detection
python -c "import jax; print(jax.devices())"

Development

Running Tests

# After clone, fetch the reference implementation submodule (for equivalence tests)
git submodule update --init --recursive

pytest tests/

Code Formatting

black beam_infer/ tests/ examples/

Type Checking

mypy beam_infer/

License

MIT License - see LICENSE file for details.
