
GPU-native backward-induction MDP solver for continuous-state stochastic dynamic programs.


bellgrid

GPU-native backward-induction for continuous-state stochastic dynamic programs.

bellgrid solves Bellman equations exactly (up to interpolation, quadrature, and action-grid discretization error) across the entire state space. It's opinionated about backward induction, vectorization, and constraints, and unopinionated about your application domain. It composes K continuous states with discrete-state primitives (DiscreteState, MarkovChain) and any mix of continuous and discrete actions, and supports asinh- or log-warped grids, scalar and multivariate Gauss-Hermite shock quadrature, and a JIT-compiled multilinear interpolation kernel that runs on CPU or CUDA.
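To make "asinh-warped grid" and "multilinear kernel" concrete, here is a minimal NumPy sketch; it is not bellgrid's implementation. `asinh_grid` and `multilinear` are hypothetical helpers, and the `scale` parameter of the warp is an assumption. The grid is uniform in asinh space, which packs points near the lower end of a wealth-like range, and the interpolator combines the 2^K corners of each enclosing cell, so it works on non-uniform (warped) axes.

```python
import numpy as np


def asinh_grid(lo, hi, n, scale=1.0):
    """Hypothetical helper: n points on [lo, hi], evenly spaced in asinh(x / scale).

    Points are roughly uniform where |x| << scale and roughly log-spaced
    where |x| >> scale. `scale` is an assumption of this sketch, not a
    documented bellgrid parameter.
    """
    u = np.linspace(np.arcsinh(lo / scale), np.arcsinh(hi / scale), n)
    grid = scale * np.sinh(u)
    grid[0], grid[-1] = lo, hi  # pin endpoints against round-off
    return grid


def multilinear(axes, values, points):
    """K-D multilinear interpolation on a tensor-product grid.

    axes:   list of K strictly increasing 1-D arrays (may be non-uniform)
    values: array of shape (len(axes[0]), ..., len(axes[-1]))
    points: array-like of shape (m, K); out-of-range queries are clamped
    """
    points = np.atleast_2d(np.asarray(points, dtype=float))
    m, K = points.shape
    lower, frac = [], []
    for k, ax in enumerate(axes):
        x = np.clip(points[:, k], ax[0], ax[-1])
        # index i of the cell [ax[i], ax[i+1]] that contains x
        i = np.clip(np.searchsorted(ax, x, side="right") - 1, 0, len(ax) - 2)
        lower.append(i)
        frac.append((x - ax[i]) / (ax[i + 1] - ax[i]))
    out = np.zeros(m)
    # weighted sum over the 2**K corners of each enclosing cell
    for corner in range(2 ** K):
        weight = np.ones(m)
        index = []
        for k in range(K):
            upper = (corner >> k) & 1
            weight *= frac[k] if upper else 1.0 - frac[k]
            index.append(lower[k] + upper)
        out += weight * values[tuple(index)]
    return out


wealth = asinh_grid(1e-3, 200.0, 128)
y = np.linspace(0.0, 1.0, 5)                # a second, uniform axis
W, Y = np.meshgrid(wealth, y, indexing="ij")
V = 2.0 * W + 3.0 * Y + 1.0                 # linear, so interpolation is exact
print(multilinear([wealth, y], V, [[10.0, 0.25], [150.0, 1.0]]))  # ≈ [21.75, 304.0]
```

Spacing near the bottom of this grid is about 0.05, while near the top it is roughly 9, so resolution goes where curvature of a log-like value function is largest.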

Quick start

pip install bellgrid                                                # CPU-only torch from PyPI
pip install bellgrid --extra-index-url https://download.pytorch.org/whl/cu126   # GPU (CUDA 12.6 torch wheels)

A minimal Merton consumption-portfolio problem (log utility, all savings held in a single risky asset with lognormal returns):

import math, torch
from bellgrid import Problem, ContinuousState, ContinuousAction, solve
from bellgrid.grids import WarpedGrid, RegularGrid
from bellgrid.shocks import Normal
from bellgrid.solvers import BackwardInduction

beta, mu, sigma = 0.96, 0.04, 0.15
# closed-form coefficients for V(w) = A + B log(w)
B = 1.0 / (1.0 - beta)
A = math.log(1 - beta) / (1 - beta) + (beta / (1 - beta) ** 2) * (math.log(beta) + mu)

def transition(state, action, shock, t):
    return {"wealth": (state["wealth"] - action["consume"]) * torch.exp(mu + sigma * shock["z"])}

def reward(state, action, shock, t):
    return torch.log(action["consume"])

problem = Problem(
    states=[ContinuousState("wealth", warp="asinh", range=(1e-3, 200.0))],
    actions=[ContinuousAction("consume", bounds=(1e-6, "wealth"))],
    transition=transition,
    reward=reward,
    shocks=[Normal("z", sigma=1.0)],
    horizon=range(0, 20),
    discount=beta,
    terminal_reward=lambda state: A + B * torch.log(state["wealth"]),
)

policy, value = solve(
    problem,
    state_grid={"wealth": WarpedGrid(n=128)},
    action_grid={"consume": RegularGrid(n=500)},
    solver=BackwardInduction(n_quad=7),
)

# Optimal consumption rate at any wealth ≈ 1 - β = 0.040
w = torch.tensor([2.0, 10.0, 25.0, 50.0])
policy({"wealth": w}, t=10)["consume"] / w
# → tensor([0.0401, 0.0401, 0.0401, 0.0401])  (closed-form: 0.04 at every point)
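You can check the closed-form coefficients above independently of bellgrid. The NumPy sketch below (not part of bellgrid) evaluates the right-hand side of the Bellman equation, log c + β E[V((w − c) exp(μ + σZ))], on a fine consumption grid at w = 10, using a 7-node probabilists' Gauss-Hermite rule from numpy.polynomial.hermite_e. That bellgrid's n_quad=7 uses this exact rule is an assumption. If V = A + B log w is correct, the maximizing c/w should be close to 1 − β and the maximized value should reproduce V(w).

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermegauss

beta, mu, sigma = 0.96, 0.04, 0.15
B = 1.0 / (1.0 - beta)
A = math.log(1 - beta) / (1 - beta) + (beta / (1 - beta) ** 2) * (math.log(beta) + mu)

# Probabilists' Gauss-Hermite rule integrates against exp(-z**2 / 2);
# dividing the weights by sqrt(2*pi) gives E[f(Z)] for Z ~ N(0, 1).
z, wq = hermegauss(7)
wq = wq / math.sqrt(2.0 * math.pi)


def V(w):
    """Closed-form value function V(w) = A + B log(w)."""
    return A + B * np.log(w)


def bellman_rhs(w, c):
    """log(c) + beta * E[V((w - c) * exp(mu + sigma * Z))], vectorised over c."""
    c = np.asarray(c, dtype=float)
    next_w = (w - c)[:, None] * np.exp(mu + sigma * z)[None, :]
    return np.log(c) + beta * (V(next_w) @ wq)


w = 10.0
consume = np.linspace(1e-6, w - 1e-6, 100_001)  # brute-force action grid
rhs = bellman_rhs(w, consume)
c_star = consume[np.argmax(rhs)]
print(c_star / w)        # ≈ 1 - beta = 0.04
print(rhs.max() - V(w))  # ≈ 0: V is a fixed point of the Bellman operator
```

Because log R = μ + σZ is linear in Z, the Gauss-Hermite rule computes E[log R] = μ exactly here, so any remaining gap comes only from the consumption grid.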

The reward can be any scalar-valued callable that fits your problem: utility maximization, cost minimization, profit, or payoff. Sign convention: bellgrid maximizes, so pass costs as negative rewards.
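As a minimal sketch of the sign convention, the NumPy snippet below writes a quadratic regulator cost and hands the maximizer its negative. The cost function, the one-step setup, and the use of NumPy arrays in place of torch tensors are illustrative assumptions; only the (state, action, shock, t) signature follows the quick start. Maximizing the negated cost over an action grid recovers the cost minimizer u = −x/2.

```python
import numpy as np


def cost(state, action, shock, t):
    """Illustrative quadratic regulator cost for a one-step problem."""
    x, u = state["x"], action["u"]
    x_next = x + u + shock["eps"]
    return x_next ** 2 + u ** 2


def reward(state, action, shock, t):
    # bellgrid maximizes, so a cost-minimization problem is passed as -cost
    return -cost(state, action, shock, t)


x = 3.0
u_grid = np.linspace(-5.0, 5.0, 10_001)  # enumerate candidate actions
r = reward({"x": x}, {"u": u_grid}, {"eps": 0.0}, t=0)
u_best = u_grid[np.argmax(r)]
print(u_best)  # ≈ -1.5, the minimizer of (x + u)**2 + u**2 at x = 3
```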

Examples

Eight canonical problems, each validated against an analytical or numerical reference. Open the .ipynb files in JupyterLab or view them on GitHub.

| Notebook | Problem | Validates against |
|---|---|---|
| 01_merton | Log-utility Merton consumption-portfolio | Closed form V = A + B log w, c/w = 1 − β |
| 02_carroll_deaton | CRRA lifecycle savings with a borrowing constraint | Endogenous Grid Method (Carroll 2006) |
| 03_american_option | American put on GBM | CRR binomial tree (n=2000), agreement within ~1e-4 |
| 04_lqg | 2-D linear-quadratic-Gaussian control | Discrete-time Riccati recursion |
| 05_two_asset_merton | 2-asset Merton with correlated returns (MultivariateNormal) | Numerical FOC for the optimal portfolio share |
| 06_regime_switching_option | American put under regime-switching vol (MarkovChain) | Bracketed by constant-vol references at σ_low, σ_high, σ_stationary |
| 07_retirement_decision | Lifecycle work vs. retire decision (DiscreteState, irreversible) | Qualitative: the retirement boundary falls with age; accumulate → retire → decumulate dynamics |
| 08_jump_diffusion_option | American put under Merton (1976) jump-diffusion (Jump + Normal, multi-shock) | European case validated against the Merton (1976) series expansion (agreement within ~1e-3); American case shows the jump premium and a lower exercise boundary |
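For reference, the kind of discrete-time Riccati recursion that 04_lqg validates against can be sketched in NumPy as follows. It assumes the standard discounted cost-minimization form (quadratic costs Q and R, terminal cost Q_T, additive Gaussian noise with covariance Sigma); the notebook's exact conventions may differ, and `riccati_lqg` is a hypothetical name. Under bellgrid's maximizing convention the reward would be the negated cost, so its value function would correspond to −V_t.

```python
import numpy as np


def riccati_lqg(A, B, Q, R, Sigma, Q_T, beta, T):
    """Backward Riccati recursion for
        min E[ sum_{t<T} beta**t (x'Qx + u'Ru) + beta**T x_T' Q_T x_T ],
        x_{t+1} = A x_t + B u_t + w_t,  w_t ~ N(0, Sigma).
    Returns gains K[t] (u_t = -K[t] x_t) and coefficients (P[t], d[t])
    with V_t(x) = x' P[t] x + d[t].
    """
    P, d = [None] * (T + 1), [0.0] * (T + 1)
    K = [None] * T
    P[T] = Q_T
    for t in range(T - 1, -1, -1):
        Pn = P[t + 1]
        # first-order condition: (R + beta B'PB) u = -beta B'PA x
        K[t] = np.linalg.solve(R + beta * B.T @ Pn @ B, beta * B.T @ Pn @ A)
        P[t] = Q + beta * A.T @ Pn @ (A - B @ K[t])
        P[t] = 0.5 * (P[t] + P[t].T)  # symmetrise against round-off
        # additive noise shifts the value by a constant but not the gain
        d[t] = beta * (d[t + 1] + np.trace(Pn @ Sigma))
    return K, P, d


A = np.array([[1.0, 0.1], [0.0, 1.0]])
B = np.array([[0.0], [0.1]])
K, P, d = riccati_lqg(A, B, Q=np.eye(2), R=np.array([[0.1]]),
                      Sigma=0.01 * np.eye(2), Q_T=np.eye(2), beta=0.95, T=50)
print(K[0])  # feedback gain at t = 0: u_0 = -K[0] @ x_0
```

Because the noise enters additively, Sigma changes only the constant d_t and never the gains K_t (certainty equivalence), which is why a grid solver's policy can be compared directly against −K_t x.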

Each notebook opens with the problem statement and equations, then runs bellgrid against the reference side by side.
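The reference used in 03_american_option, a Cox-Ross-Rubinstein (CRR) binomial tree, is short enough to sketch. This is the textbook CRR scheme in NumPy, not code taken from the notebook, and the contract parameters in the usage line are illustrative rather than necessarily the notebook's.

```python
import math

import numpy as np


def crr_american_put(S0, K, r, sigma, T, n):
    """Cox-Ross-Rubinstein binomial price of an American put."""
    dt = T / n
    u = math.exp(sigma * math.sqrt(dt))
    d = 1.0 / u
    disc = math.exp(-r * dt)
    p = (math.exp(r * dt) - d) / (u - d)  # risk-neutral up probability
    # terminal stock prices S0 * u**j * d**(n - j), j = number of up moves
    j = np.arange(n + 1)
    S = S0 * u ** j * d ** (n - j)
    V = np.maximum(K - S, 0.0)
    for step in range(n - 1, -1, -1):
        j = np.arange(step + 1)
        S = S0 * u ** j * d ** (step - j)
        # node j moves up to j + 1 or down to j at the next step
        cont = disc * (p * V[1:step + 2] + (1.0 - p) * V[:step + 1])
        V = np.maximum(cont, K - S)  # early exercise
    return V[0]


price = crr_american_put(S0=100.0, K=100.0, r=0.05, sigma=0.2, T=1.0, n=2000)
print(price)  # exceeds the European Black-Scholes put for this contract (about 5.57)
```

The early-exercise step is a one-dimensional Bellman maximization over {continue, exercise}, which is the same structure bellgrid solves on a continuous price grid.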

What's built

  • States: ContinuousState (with optional asinh / log warp), DiscreteState, MarkovChain (any number per problem; cost is additive in chains).
  • Actions: ContinuousAction (with optional state-dependent bounds), DiscreteAction.
  • Shocks: Normal, Lognormal, MultivariateNormal (Gauss-Hermite); Uniform on [low, high] (Gauss-Legendre); Categorical with finite support (exact); Jump (Bernoulli-approximated Poisson with Normal log-magnitudes). Multiple independent shocks per problem are supported via tensor-product quadrature.
  • Grids: RegularGrid, WarpedGrid.
  • Solvers: BackwardInduction for finite-horizon problems, PolicyIteration (value iteration under the hood) for infinite-horizon stationary problems. Runs on CPU or CUDA, with JIT-compiled K-D multilinear interpolation.
  • Simulator: simulate() reuses the same transition and reward callables you pass to the solver, so the simulated model and the solved model can't drift apart.
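The Shocks bullet can be illustrated with a NumPy sketch of how independent Gauss-Hermite rules combine into a tensor-product rule, and how a Cholesky factor turns that rule into one for correlated normals. The helper names are hypothetical, and the Cholesky mapping is one standard construction, assumed here rather than taken from bellgrid's MultivariateNormal.

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def normal_rule(n):
    """1-D rule with sum(w * f(z)) ≈ E[f(Z)] for Z ~ N(0, 1)."""
    z, w = hermegauss(n)
    return z, w / math.sqrt(2.0 * math.pi)


def tensor_product_rule(n_per_shock):
    """Nodes (N, m) and weights (N,) for m independent N(0, 1) shocks,
    with N = n_1 * ... * n_m, so cost grows multiplicatively in shocks."""
    rules = [normal_rule(n) for n in n_per_shock]
    grids = np.meshgrid(*[z for z, _ in rules], indexing="ij")
    nodes = np.stack([g.ravel() for g in grids], axis=-1)
    weights = np.ones(1)
    for _, w in rules:
        # same ordering as ravel(): the last shock varies fastest
        weights = np.outer(weights, w).ravel()
    return nodes, weights


def correlated_normal_rule(mean, cov, n):
    """Assumed construction for correlated normals: map an independent
    tensor rule through the Cholesky factor, x = mean + L z, L L^T = cov."""
    L = np.linalg.cholesky(np.asarray(cov, dtype=float))
    nodes, weights = tensor_product_rule([n] * L.shape[0])
    return np.asarray(mean, dtype=float) + nodes @ L.T, weights


nodes, weights = tensor_product_rule([7, 7])
a, b = 0.15, 0.3
approx = weights @ np.exp(a * nodes[:, 0] + b * nodes[:, 1])
print(approx, math.exp((a ** 2 + b ** 2) / 2))  # agree to near machine precision
```

With 7 nodes per shock, two shocks already need 49 evaluation points per state-action pair, which is why the number of simultaneous shocks is a major cost driver.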

Planned

  • Implicit differentiation of policy / value with respect to model parameters.
  • Local action search instead of grid enumeration (for problems with many continuous actions).

When to use bellgrid (vs. RL)

Use bellgrid when you have (or can write) a transition model, the state has roughly 1–6 continuous dimensions plus any discrete components, and you need a policy that is correct across the entire support, including tails the agent rarely visits.

| | bellgrid | RL |
|---|---|---|
| State-dimension sweet spot | 1–6 continuous + discrete | Thousands |
| Correctness | Exact across the full grid | Approximate, on-distribution |
| Tail / edge-case behavior | By construction | Only if explored in training |
| Constraints / kinks | First-class | Hard to encode |
| Off-policy what-ifs | Cheap recompute | Full retrain |
| You don't have a model | Doesn't apply | Where RL wins |

A longer version with concrete borderline cases is in docs/when_to_use.md; the full API surface is in docs/api.md.

License

MIT.



Download files

Download the file for your platform.

Source Distribution

bellgrid-0.1.0a2.tar.gz (1.9 MB)

Uploaded Source

Built Distribution


bellgrid-0.1.0a2-py3-none-any.whl (40.1 kB)

Uploaded Python 3

File details

Details for the file bellgrid-0.1.0a2.tar.gz.

File metadata

  • Download URL: bellgrid-0.1.0a2.tar.gz
  • Upload date:
  • Size: 1.9 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.14

File hashes

Hashes for bellgrid-0.1.0a2.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6a84dd0fcc653db1f6a81404ebaec52640567e3852b2867703e6df4b810d4ca1 |
| MD5 | 4ba94ed5cd41d09b9a1273964450a9c3 |
| BLAKE2b-256 | dddddcefc67fb227e710e46332198cc6e972d8b32ea34678b4e32cd71b5d1512 |


File details

Details for the file bellgrid-0.1.0a2-py3-none-any.whl.

File metadata

  • Download URL: bellgrid-0.1.0a2-py3-none-any.whl
  • Upload date:
  • Size: 40.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.14

File hashes

Hashes for bellgrid-0.1.0a2-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 35a3fb8d6252f92003acd8e3e3dd95da313b35488919d25902d87550aeaeee2e |
| MD5 | 4efb302df7cb641de16678b962cac8ca |
| BLAKE2b-256 | 47ac4a44a8f493c6664fbb75600e1cc15ca9c81f6eb4a804dee5a869d7f3e8a7 |

