
GPU-native backward-induction MDP solver for continuous-state stochastic dynamic programs.

Project description

bellgrid

GPU-native backward-induction for continuous-state stochastic dynamic programs.

bellgrid solves Bellman equations across the entire state space, exactly up to interpolation, quadrature, and action-grid discretization error. It's opinionated about backward induction, vectorization, and constraints, and unopinionated about your application domain. It composes K continuous states with discrete-state primitives (DiscreteState, MarkovChain) and any mix of continuous and discrete actions, and it supports asinh/log-warped grids, scalar and multivariate Gauss-Hermite shock quadrature, and a JIT-compiled multilinear interpolation kernel that runs on CPU or CUDA.
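A warped grid places points uniformly in a transformed coordinate and maps them back, so resolution concentrates where the value function curves most. The NumPy sketch below shows the idea. The helper names and the `scale` parameter are hypothetical, and this is not bellgrid's WarpedGrid implementation.

```python
import numpy as np

def asinh_grid(lo, hi, n, scale=1.0):
    """Hypothetical helper: points uniform in asinh(x / scale), mapped back via scale * sinh(.).

    `scale` (an assumption of this sketch) sets where spacing switches from
    roughly linear (|x| << scale) to roughly logarithmic (|x| >> scale).
    """
    u = np.linspace(np.arcsinh(lo / scale), np.arcsinh(hi / scale), n)
    return scale * np.sinh(u)

def log_grid(lo, hi, n):
    """Hypothetical helper: points uniform in log(x); requires lo > 0."""
    return np.exp(np.linspace(np.log(lo), np.log(hi), n))

g = asinh_grid(1e-3, 200.0, 128)
print(g[:3], g[-3:])           # fine spacing near 1e-3, coarse spacing near 200
print(np.diff(g)[[0, -1]])     # first vs last gap
```

For x much larger than `scale`, asinh(x / scale) ≈ log(2x / scale), so the asinh grid spaces points geometrically in the upper tail like a log grid. Unlike log, it stays defined at zero and for negative values, which matters for states such as net worth.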

Quick start

git clone https://github.com/tbb300/bellgrid && cd bellgrid
uv sync --extra examples

A minimal Merton consumption-portfolio (log utility, single risky asset, lognormal returns):

import math, torch
from bellgrid import Problem, ContinuousState, ContinuousAction, solve
from bellgrid.grids import WarpedGrid, RegularGrid
from bellgrid.shocks import Normal
from bellgrid.solvers import BackwardInduction

beta, mu, sigma = 0.96, 0.04, 0.15
# closed-form coefficients for V(w) = A + B log(w)
B = 1.0 / (1.0 - beta)
A = math.log(1 - beta) / (1 - beta) + (beta / (1 - beta) ** 2) * (math.log(beta) + mu)

def transition(state, action, shock, t):
    return {"wealth": (state["wealth"] - action["consume"]) * torch.exp(mu + sigma * shock["z"])}

def reward(state, action, shock, t):
    return torch.log(action["consume"])

problem = Problem(
    states=[ContinuousState("wealth", warp="asinh", range=(1e-3, 200.0))],
    actions=[ContinuousAction("consume", bounds=(1e-6, "wealth"))],
    transition=transition,
    reward=reward,
    shocks=[Normal("z", sigma=1.0)],
    horizon=range(0, 20),
    discount=beta,
    terminal_reward=lambda state: A + B * torch.log(state["wealth"]),
)

policy, value = solve(
    problem,
    state_grid={"wealth": WarpedGrid(n=128)},
    action_grid={"consume": RegularGrid(n=500)},
    solver=BackwardInduction(n_quad=7),
)

# Optimal consumption rate at any wealth ≈ 1 - β = 0.040
w = torch.tensor([2.0, 10.0, 25.0, 50.0])
policy({"wealth": w}, t=10)["consume"] / w
# → tensor([0.0401, 0.0401, 0.0401, 0.0401])  (closed-form: 0.04 at every point)
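The closed form used in this example can be checked independently of bellgrid. The NumPy sketch below substitutes V(w) = A + B log w into the right-hand side of the Bellman equation, takes the expectation over z ~ N(0, 1) with a 7-node probabilists' Gauss-Hermite rule (`numpy.polynomial.hermite_e.hermegauss`), and maximizes over a fine grid of consumption rates. It is an illustrative check, not part of bellgrid's API.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

beta, mu, sigma = 0.96, 0.04, 0.15
B = 1.0 / (1.0 - beta)
A = np.log(1 - beta) / (1 - beta) + (beta / (1 - beta) ** 2) * (np.log(beta) + mu)

def value(w):
    """Closed-form value function V(w) = A + B log w."""
    return A + B * np.log(w)

# Probabilists' Gauss-Hermite nodes/weights (weight function exp(-z^2 / 2)).
# Normalizing the weights makes sum(q * f(z)) approximate E[f(Z)] for Z ~ N(0, 1).
z, q = hermegauss(7)
q = q / q.sum()
gross_return = np.exp(mu + sigma * z)           # one gross return per node

fracs = np.linspace(0.001, 0.2, 1991)           # candidate consumption rates c / w, step 1e-4

def bellman_check(w):
    """Maximize the Bellman right-hand side at wealth w; return (best c/w, RHS max - V(w))."""
    c = fracs * w                                # shape (n_fracs,)
    w_next = (w - c)[:, None] * gross_return     # shape (n_fracs, n_nodes)
    rhs = np.log(c) + beta * (value(w_next) @ q)
    i = np.argmax(rhs)
    return fracs[i], rhs[i] - value(w)

for w in [2.0, 10.0, 25.0, 50.0]:
    frac, resid = bellman_check(w)
    print(f"w = {w:5.1f}   c/w = {frac:.4f}   residual = {resid:+.1e}")
```

Because log R = μ + σz is linear in z, the quadrature integrates it exactly, so the maximizing rate should land on 1 − β = 0.04 and the residual should sit at rounding level.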

The reward can be any scalar-valued callable that fits your problem: utility maximization, cost minimization, profit, payoff. Sign convention: bellgrid maximizes, so negate costs.
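To illustrate the sign convention, here is a hypothetical scalar linear-quadratic problem written in the same `reward(state, action, shock, t)` form as the quick start, paired with a NumPy Riccati recursion as the reference. The names `x`, `u`, `a`, `b`, `s`, `q`, `r` and all parameter values are assumptions for this sketch. Because the reward is the negated cost, the maximized value is V_t(x) = -(P_t x² + c_t), the negative of the usual minimized LQ cost.

```python
import numpy as np

# Hypothetical scalar LQ problem: x' = a*x + b*u + s*eps, eps ~ N(0, 1),
# per-period cost q*x^2 + r*u^2, discount beta.
a, b, s, q, r, beta = 1.0, 0.5, 0.1, 1.0, 0.1, 0.95

def reward(state, action, shock, t):
    # Maximizer convention: return the negated cost. Plain arithmetic, so it works
    # on Python floats, NumPy arrays, or torch tensors alike.
    x, u = state["x"], action["u"]
    return -(q * x**2 + r * u**2)

def riccati(horizon, P=0.0, c=0.0):
    """Backward recursion for V_t(x) = -(P_t x^2 + c_t), starting from V_T = -(P x^2 + c).

    Returns (P_0, c_0, K_0), where the optimal first-period action is u = -K_0 x.
    """
    K = 0.0
    for _ in range(horizon):
        denom = r + beta * b**2 * P               # uses next-period P
        K = beta * a * b * P / denom
        c = beta * (P * s**2 + c)
        P = q + beta * a**2 * P - (beta * a * b * P) ** 2 / denom
    return P, c, K

# Brute-force check of one Bellman step at x0 against the recursion
P1, c1, _ = riccati(49)
P0, c0, K0 = riccati(50)
x0 = 1.5
u = np.linspace(-5.0, 5.0, 100001)
# E[V_1(x')] = -(P1 * ((a*x0 + b*u)^2 + s^2) + c1) because E[eps] = 0 and E[eps^2] = 1
rhs = reward({"x": x0}, {"u": u}, None, 0) - beta * (P1 * ((a * x0 + b * u) ** 2 + s**2) + c1)
print(u[np.argmax(rhs)], -K0 * x0)            # optimal action: grid vs recursion
print(rhs.max(), -(P0 * x0**2 + c0))          # optimal value: grid vs recursion
```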

Examples

Eight canonical problems, each checked against an analytical, numerical, or (for 07) qualitative reference. Open the .ipynb files in JupyterLab or view them on GitHub.

| Notebook | Problem | Validated against |
|---|---|---|
| 01_merton | Log-utility Merton consumption-portfolio | Closed form V = A + B log w, c/w = 1 − β |
| 02_carroll_deaton | CRRA lifecycle savings with a borrowing constraint | Endogenous Grid Method (Carroll 2006) |
| 03_american_option | American put on GBM | CRR binomial tree (n=2000), agreement within ~1e-4 |
| 04_lqg | 2-D linear-quadratic-Gaussian control | Discrete-time Riccati recursion |
| 05_two_asset_merton | 2-asset Merton with correlated returns (MultivariateNormal) | Numerical FOC for the optimal portfolio share |
| 06_regime_switching_option | American put under regime-switching vol (MarkovChain) | Bracketed by constant-vol references at σ_low, σ_high, σ_stationary |
| 07_retirement_decision | Lifecycle work-vs-retire decision (DiscreteState, irreversible) | Qualitative: boundary falls with age; accumulate → retire → decumulate dynamics |
| 08_jump_diffusion_option | American put under Merton (1976) jump-diffusion (Jump + Normal, multi-shock) | European case validated against the Merton 1976 series expansion (agreement within ~1e-3); American case shows the jump premium and lower exercise boundary |
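The CRR binomial tree used as the reference for 03_american_option is itself a backward induction. The NumPy sketch below prices a put on a Cox-Ross-Rubinstein tree with an optional early-exercise check, and compares the European case with Black-Scholes. The contract parameters are hypothetical and not taken from the notebook.

```python
import numpy as np
from math import erf, exp, log, sqrt

def crr_put(S0, K, r, sigma, T, n, american=True):
    """Cox-Ross-Rubinstein binomial price of a put; early exercise if american=True."""
    dt = T / n
    u = np.exp(sigma * np.sqrt(dt))
    d = 1.0 / u
    disc = np.exp(-r * dt)
    p = (np.exp(r * dt) - d) / (u - d)              # risk-neutral probability of an up move
    j = np.arange(n + 1)                             # number of up moves at maturity
    V = np.maximum(K - S0 * u**j * d**(n - j), 0.0)  # payoff at maturity
    for step in range(n - 1, -1, -1):                # backward induction through the tree
        j = np.arange(step + 1)
        V = disc * (p * V[1:] + (1.0 - p) * V[:-1])  # discounted continuation value
        if american:
            V = np.maximum(V, K - S0 * u**j * d**(step - j))  # exercise if worth more
    return float(V[0])

def bs_put(S0, K, r, sigma, T):
    """Black-Scholes European put, used to sanity-check the tree."""
    N = lambda x: 0.5 * (1.0 + erf(x / sqrt(2.0)))
    d1 = (log(S0 / K) + (r + 0.5 * sigma**2) * T) / (sigma * sqrt(T))
    d2 = d1 - sigma * sqrt(T)
    return K * exp(-r * T) * N(-d2) - S0 * N(-d1)

S0, K, r, sigma, T = 100.0, 100.0, 0.05, 0.2, 1.0    # hypothetical contract
print(crr_put(S0, K, r, sigma, T, n=2000))                   # American put
print(crr_put(S0, K, r, sigma, T, n=2000, american=False))   # European put
print(bs_put(S0, K, r, sigma, T))                            # closed-form European put
```

The American price exceeds the European one by the early-exercise premium, which is what the bellgrid solution is compared against.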

Each notebook opens with the problem statement and equations, then compares bellgrid's solution with the reference side by side.

What's built

  • States: ContinuousState (with optional asinh / log warp), DiscreteState, MarkovChain.
  • Actions: ContinuousAction (with optional state-dependent bounds), DiscreteAction.
  • Shocks: Normal, Lognormal, MultivariateNormal, and Jump (Bernoulli-approximated Poisson with Normal log-magnitudes), all integrated with Gauss-Hermite quadrature. Multiple independent shocks per problem are supported via tensor-product quadrature.
  • Grids: RegularGrid, WarpedGrid.
  • Solvers: BackwardInduction for finite-horizon problems, PolicyIteration (value iteration under the hood) for infinite-horizon stationary problems. CPU or CUDA, JIT-compiled K-D multilinear interpolation.
  • Simulator: simulate() calls the same user-supplied transition and reward as the solver, so the simulated and solved models can't drift apart.
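The K-D multilinear interpolation listed under Solvers can be sketched in plain NumPy: locate each query's grid cell, then blend the 2^K cell corners with products of 1-D linear weights. This is a hypothetical reference implementation for illustration, not bellgrid's JIT-compiled kernel, and its edge-clamping behavior is an assumption of the sketch.

```python
import itertools
import numpy as np

def multilinear_interp(axes, values, points):
    """Multilinear interpolation on a rectilinear K-D grid.

    axes:   list of K strictly increasing 1-D arrays, one per dimension
    values: array of shape (len(axes[0]), ..., len(axes[-1]))
    points: array of shape (M, K); coordinates outside the grid are clamped to its edges
    """
    points = np.asarray(points, dtype=float)
    K = len(axes)
    left, frac = [], []
    for k, ax in enumerate(axes):
        x = np.clip(points[:, k], ax[0], ax[-1])
        # Left cell edge, restricted to [0, len(ax) - 2] so that left + 1 is a valid index
        i = np.clip(np.searchsorted(ax, x, side="right") - 1, 0, len(ax) - 2)
        left.append(i)
        frac.append((x - ax[i]) / (ax[i + 1] - ax[i]))
    out = np.zeros(len(points))
    # Sum the 2^K corners of each point's cell, weighted by products of 1-D linear weights
    for corner in itertools.product((0, 1), repeat=K):
        weight = np.ones(len(points))
        index = []
        for k, bit in enumerate(corner):
            weight *= frac[k] if bit else 1.0 - frac[k]
            index.append(left[k] + bit)
        out += weight * values[tuple(index)]
    return out

# Usage: a bilinear function is reproduced exactly inside the grid
xs, ys = np.linspace(0.0, 1.0, 5), np.linspace(-1.0, 2.0, 7)
X, Y = np.meshgrid(xs, ys, indexing="ij")
f = lambda x, y: 1.0 + 2.0 * x + 3.0 * y + 4.0 * x * y
pts = np.array([[0.3, 0.7], [0.95, -0.2], [1.0, 2.0]])
print(multilinear_interp([xs, ys], f(X, Y), pts))
print(f(pts[:, 0], pts[:, 1]))
```

The 2^K corner loop is one reason per-query cost rises quickly with the number of continuous dimensions, on top of the growth of the grid itself.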

Planned

  • Implicit differentiation of policy / value wrt model parameters.
  • Local action search instead of grid enumeration (for problems with many continuous actions).

When to use bellgrid (vs. RL)

You have (or can write) a transition model, the state has roughly 1–6 continuous dimensions plus discrete components, and you need a policy that is correct across the entire support, including tails the agent rarely visits.

| | bellgrid | RL |
|---|---|---|
| State dim sweet spot | 1–6 continuous + discrete | Thousands |
| Correctness | Exact across the full grid | Approximate, on-distribution |
| Tail / edge-case behavior | By construction | Only if explored in training |
| Constraints / kinks | First-class | Hard to encode |
| Off-policy what-ifs | Cheap recompute | Full retrain |
| You don't have a model | Doesn't apply | Where RL wins |

A longer version with concrete borderline cases is in docs/when_to_use.md; the full API surface is documented in docs/api.md.

License

MIT.



Download files


Source Distribution

bellgrid-0.1.0a0.tar.gz (1.8 MB)

Uploaded Source

Built Distribution


bellgrid-0.1.0a0-py3-none-any.whl (30.3 kB)

Uploaded Python 3

File details

Details for the file bellgrid-0.1.0a0.tar.gz.

File metadata

  • Download URL: bellgrid-0.1.0a0.tar.gz
  • Upload date:
  • Size: 1.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.14

File hashes

Hashes for bellgrid-0.1.0a0.tar.gz
Algorithm Hash digest
SHA256 055b99cd1e11120287f9c49a42f8f621753abf51e2d082585074f3bf31cd6656
MD5 2a72b6debd4c7d34a2eb1c1cabc9fafa
BLAKE2b-256 adf589d7fd38950ee887c04037697b7753a48e7f7a555ffef5d2e0680b5934d1


File details

Details for the file bellgrid-0.1.0a0-py3-none-any.whl.

File metadata

  • Download URL: bellgrid-0.1.0a0-py3-none-any.whl
  • Upload date:
  • Size: 30.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.14

File hashes

Hashes for bellgrid-0.1.0a0-py3-none-any.whl
Algorithm Hash digest
SHA256 b3161ec8b5d7fe64349831639703cf50fa0e67dee379b70c7fae254840cab9cd
MD5 828badd216cdc2997d1ec11d5504ff9e
BLAKE2b-256 59d90edc36966243c6aebe07c476fa9f089b2bbb772b2e7626e2d40069c50608

