HJB Solver

A powerful and efficient Python package for solving Hamilton-Jacobi-Bellman (HJB) equations using policy iteration. Built on JAX for high performance and JIT compilation.

Features

  • Easy to Use: Simple abstract base classes for quick implementation
  • High Performance: JAX-powered with JIT compilation and automatic differentiation
  • Flexible: Handles complex boundary conditions and endogenous boundaries
  • Comprehensive: Built-in Newton solver, boundary search, and solution analysis
  • Well-Tested: Production-ready with extensive examples

Installation

From PyPI (Recommended)

pip install hjb-solver

Requirements

  • Python 3.12+
  • JAX: High-performance numerical computing
  • JAXtyping: Type annotations for JAX arrays
  • NumPy: Numerical operations
  • Pandas: Data analysis and export

All dependencies are automatically installed with the package. See pyproject.toml for the complete dependency list.

Quick Start

Get up and running in 5 minutes with the BCW2011 liquidation problem from "A Unified Theory of Tobin's q, Corporate Investment, Financing, and Risk Management" by Bolton, Chen, and Wang (2011), Case I: Liquidation:

from hjb_solver import *

# 1. Define your model parameters
@struct.dataclass(frozen=True)
class Parameter(AbstractParameter):
    r: float = 0.06         # Risk-free rate
    delta: float = 0.1007   # Rate of depreciation  
    mu: float = 0.18        # Productivity drift
    sigma: float = 0.09     # Productivity volatility
    theta: float = 1.5      # Adjustment cost parameter
    lambda_: float = 0.01   # Cash carrying cost
    l: float = 0.9          # Liquidation value

# 2. Define your policy variables
class PolicyDict(AbstractPolicyDict):
    investment: Array       # Investment rate

# 3. Define boundary conditions with dependencies
@dataclass
class Boundary(AbstractBoundary[Parameter]):
    def independent_boundary(self):
        return {"v_left": self.p.l}  # Liquidation value at left boundary
  
    def dependent_boundary(self):
        if self.s_max is None:
            raise ValueError("s_max must be provided.")
  
        # Complex payout boundary calculation
        sqrt_term = (
            self.p.r + self.p.delta + (self.s_max + 1) / self.p.theta
        ) ** 2 - (2 / self.p.theta) * (
            self.p.mu + (self.p.r + self.p.delta - self.p.lambda_) * self.s_max
            + (self.s_max + 1) ** 2 / (2 * self.p.theta)
        )
  
        v_right = self.p.theta * (
            (self.p.r + self.p.delta + (self.s_max + 1) / self.p.theta)
            - sqrt_term ** 0.5
        )
  
        return {"v_right": v_right}, {"s_max"}  # v_right depends on s_max

# 4. Implement your solver with endogenous boundary
class Solver(AbstractSolver[Parameter, PolicyDict]):
    def initialize_policy(self):
        # Start with frictionless optimal investment
        inv_fb = (
            self.p.r + self.p.delta - (
                (self.p.r + self.p.delta) ** 2 
                - 2 * (self.p.mu - (self.p.r + self.p.delta)) / self.p.theta
            ) ** 0.5
        )
        return PolicyDict(investment=jnp.full_like(self.s, inv_fb))
  
    def update_policy(self, v, dv, d2v, s, p):
        # First-order condition: 1 + θi = (V - sV')/V'
        investment = (1 / p.theta) * (v / dv - s - 1)
        return PolicyDict(investment=investment)
  
    @staticmethod
    def hjb_residual(v, dv, d2v, s, policy, p):
        inv = policy["investment"]
  
        # HJB equation for firm value with investment and cash management
        capital_evolution = (inv - p.delta) * (v - s * dv)
        cash_flow = ((p.r - p.lambda_) * s + p.mu - inv - 0.5 * p.theta * inv**2) * dv
        uncertainty = 0.5 * p.sigma**2 * d2v
        discount = -p.r * v
  
        return capital_evolution + cash_flow + uncertainty + discount
  
    def bisection_boundary_error(self, solution):
        # Target: super-contact condition V''(s_max) = 0
        return solution.boundary_derivative.d2v_right

# 5. Solve with endogenous boundary optimization
parameter = Parameter()
boundary = Boundary(p=parameter, s_min=0.0, s_max=0.22)  # Initial guess for payout boundary
solver = Solver(p=parameter, boundary=boundary, guess_policy=True)

# Find optimal payout boundary using bisection search
info = solver.bisection_search(
    boundary_name="s_max",
    low=0.1,
    high=0.3,
)

pp(info)

# 6. Analyze results
solution = solver.solution
pp(solution.df)

That's it! You've solved a complete corporate finance problem with endogenous boundaries.
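
The frictionless starting guess in step 4 is the smaller root of the quadratic i² - 2(r+δ)i + 2(μ - (r+δ))/θ = 0, which follows from combining the standard first-best conditions q = 1 + θi and q = (μ - i - θi²/2)/(r + δ - i). The plain-Python check below is a sketch that uses only the default parameter values above and no hjb_solver calls; the helper name first_best_investment is made up for illustration.

```python
import math

# Default Quick Start parameter values
r, delta, mu, theta = 0.06, 0.1007, 0.18, 1.5


def first_best_investment(r, delta, mu, theta):
    """Smaller root of i^2 - 2(r+delta) i + 2(mu - (r+delta))/theta = 0."""
    a = r + delta
    return a - math.sqrt(a**2 - 2.0 * (mu - a) / theta)


inv_fb = first_best_investment(r, delta, mu, theta)
a = r + delta

# The formula should satisfy the quadratic it was derived from
quadratic_residual = inv_fb**2 - 2.0 * a * inv_fb + 2.0 * (mu - a) / theta

# Tobin's q implied by the two first-best conditions should coincide
q_from_foc = 1.0 + theta * inv_fb
q_from_value = (mu - inv_fb - 0.5 * theta * inv_fb**2) / (a - inv_fb)

print(f"i_fb = {inv_fb:.4f}, q_fb = {q_from_foc:.4f}")
```

If the parameters are changed so that the term under the square root turns negative, no real first-best investment rate exists and math.sqrt raises a ValueError, which is a useful early warning before running the solver.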

Core Concepts

The HJB Solver framework is built around four abstract base classes.

1. AbstractParameter - Model Parameters

Purpose: Store all model parameters in a JAX-compatible, immutable container.

This class holds all the economic parameters of your model in a structured, type-safe way. The immutable design ensures parameters cannot be accidentally modified during solution, while JAX compatibility enables efficient numerical computation.

Key Requirements:

  • Use @struct.dataclass(frozen=True) decorator for immutability
  • All fields must be JAX-compatible types (float, int, Array)
  • Inherit from AbstractParameter

@struct.dataclass(frozen=True)
class Parameter(AbstractParameter):
    # Economic fundamentals
    discount_rate: float = ...      # Time preference ρ
    risk_aversion: float = ...      # Relative risk aversion γ
  
    # Technology parameters
    productivity: float = ...       # Production efficiency μ
    depreciation: float = ...       # Capital depreciation δ
  
    # Financial frictions
    borrowing_cost: float = ...     # Credit spread λ
    adjustment_cost: float = ...    # Investment adjustment θ
    # ... other parameters
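
The effect of frozen=True can be illustrated with the standard library's dataclasses module, used here only as a stand-in for struct.dataclass. The class ToyParameter and its fields are hypothetical. Assigning to a field raises an error, so parameter variations are created as new instances:

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ToyParameter:
    discount_rate: float = 0.05
    depreciation: float = 0.10


base = ToyParameter()

# Direct mutation is rejected on a frozen dataclass
try:
    base.discount_rate = 0.08
    mutated = True
except FrozenInstanceError:
    mutated = False

# A variation is a new instance; the original stays untouched
high_rate = replace(base, discount_rate=0.08)
print(base, high_rate, mutated)
```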

2. AbstractPolicyDict - Policy Variables

Purpose: Define the policy variables (controls) in your HJB problem.

This class serves as a container for all decision variables that the agent optimizes. Each policy represents the optimal choice as a function of the state variable across the entire grid.

Key Requirements:

  • Inherit from AbstractPolicyDict
  • Each policy is an Array representing values across the state grid
  • Use descriptive names matching your economic model

Design Principles:

  • Descriptive naming: Use clear economic terms (e.g., consumption, investment, hedging)
  • Complete specification: Include all control variables in your model

class PolicyDict(AbstractPolicyDict):
    control_var1: Array      # First control variable
    control_var2: Array      # Second control variable
    # ... other policy variables

3. AbstractBoundary - Boundary Conditions

Purpose: Specify value function boundaries, with support for complex interdependent boundaries.

Boundary conditions are crucial for well-posed HJB problems. This class handles both simple fixed boundaries and complex cases where boundary values depend on endogenous variables that must be solved jointly with the value function.

Implementation Patterns:

Simple Pattern (Fixed Boundaries)

Use this when all boundary values are known constants:

@dataclass
class Boundary(AbstractBoundary[Parameter]):
    pass

# Usage with all boundaries specified
boundary = Boundary(p=params, s_min=..., s_max=..., v_left=..., v_right=...)

Advanced Pattern (Computed Boundaries)

Use this when boundary values depend on parameters or other boundaries:

@dataclass  
class Boundary(AbstractBoundary[Parameter]):
    def independent_boundary(self):
        # Specify known boundaries first
        return {"s_min": ..., "v_left": ...}
  
    def dependent_boundary(self):
        # Compute boundaries that depend on others
        terminal_value = ...  # Complex computation based on self.s_max
        return {"v_right": terminal_value}, {"s_max"}

Method Details:

  • independent_boundary(): Returns dictionary of boundaries that can be set independently
  • dependent_boundary(): Returns tuple of (boundary_dict, dependency_set) for computed boundaries
  • Dependencies in the second tuple element must be provided when creating the boundary instance
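
As an illustration of a dependent boundary, the Quick Start's v_right can be checked against the HJB equation itself. The sketch below assumes that V'(s_max) = 1 and V''(s_max) = 0 hold at the payout boundary (an assumption about the underlying model, not something the library is known to enforce) and confirms in plain Python that the closed-form v_right then makes the Quick Start residual vanish:

```python
import math

# Quick Start defaults
r, delta, mu, sigma, theta, lam = 0.06, 0.1007, 0.18, 0.09, 1.5, 0.01


def v_right(s_max):
    """Closed-form boundary value from the Quick Start's dependent_boundary()."""
    a = r + delta + (s_max + 1.0) / theta
    inside = a**2 - (2.0 / theta) * (
        mu + (r + delta - lam) * s_max + (s_max + 1.0) ** 2 / (2.0 * theta)
    )
    return theta * (a - math.sqrt(inside))


def hjb_residual(v, dv, d2v, s, inv):
    """Same terms as the Quick Start's hjb_residual, written for scalars."""
    capital_evolution = (inv - delta) * (v - s * dv)
    cash_flow = ((r - lam) * s + mu - inv - 0.5 * theta * inv**2) * dv
    uncertainty = 0.5 * sigma**2 * d2v
    return capital_evolution + cash_flow + uncertainty - r * v


residuals = []
for s_max in (0.1, 0.22, 0.3):
    v = v_right(s_max)
    dv, d2v = 1.0, 0.0                      # assumed boundary derivatives
    inv = (v / dv - s_max - 1.0) / theta    # first-order condition
    residuals.append(hjb_residual(v, dv, d2v, s_max, inv))

print(residuals)
```

This also shows why v_right is declared as depending on s_max: every trial value of s_max implies a different boundary value.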

4. AbstractSolver - Solution Algorithm

Purpose: Implement the economic logic of your HJB equation through three core methods that define the dynamic programming problem.

The AbstractSolver class serves as the heart of the framework, where you encode the economic behavior of your model. It handles the iterative solution process through policy iteration, automatically managing value function updates, policy optimization, and convergence checking.
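
Policy iteration alternates between evaluating the value of a fixed policy and improving the policy given that value. The sketch below shows the idea on a hypothetical two-state, two-action discounted problem in plain Python. It is not the package's implementation, which works on a continuous state grid, and all numbers are made up for illustration.

```python
BETA = 0.9  # discount factor (illustrative)

# REWARD[s][a] is the flow payoff; PROB[s][a] is the probability of moving to state 0
REWARD = [[1.0, 0.5], [0.0, 2.0]]
PROB = [[0.9, 0.2], [0.6, 0.1]]


def evaluate(policy):
    """Policy evaluation: solve (I - BETA * P_pi) V = r_pi by Cramer's rule."""
    r0, r1 = REWARD[0][policy[0]], REWARD[1][policy[1]]
    p0, p1 = PROB[0][policy[0]], PROB[1][policy[1]]
    a, b = 1.0 - BETA * p0, -BETA * (1.0 - p0)
    c, d = -BETA * p1, 1.0 - BETA * (1.0 - p1)
    det = a * d - b * c
    return [(r0 * d - b * r1) / det, (a * r1 - c * r0) / det]


def q_value(s, a, v):
    p = PROB[s][a]
    return REWARD[s][a] + BETA * (p * v[0] + (1.0 - p) * v[1])


def improve(v):
    """Policy improvement: pick the action with the larger Q-value in each state."""
    return [max((0, 1), key=lambda a: q_value(s, a, v)) for s in (0, 1)]


policy = [0, 0]
for iteration in range(20):
    value = evaluate(policy)
    new_policy = improve(value)
    if new_policy == policy:   # policy is stable, so it is optimal
        break
    policy = new_policy

print(iteration, policy, value)
```

In the continuous-state setting the evaluation step becomes the solution of a differential equation on the grid and the improvement step is the update_policy() method described below.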

Core Methods (Required)

1. initialize_policy() -> PolicyDict

Provides the initial guess for all policy variables to start the iteration process. A good initial guess can significantly improve convergence speed and help avoid local optima.

Common strategies include:

  • Frictionless solutions: Start with the optimal policy ignoring all frictions
  • Steady-state values: Use long-run equilibrium values as constant policies
  • Simple heuristics: Linear interpolation between boundary values

2. update_policy(v, dv, d2v, s, p) -> PolicyDict

Implements the first-order conditions (FOCs) of your optimization problem. Given the current value function and its derivatives, this method computes the optimal policy at each grid point.

The method receives:

  • v: Value function V(s)
  • dv: First derivative V'(s)
  • d2v: Second derivative V''(s)
  • s: State variable grid points
  • p: Model parameters

Two implementation approaches:

  1. Direct computation: When FOCs can be solved analytically for the control variables
  2. Newton iteration: For complex policies that require numerical solution of FOCs using self.solve_policy()
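
The second approach can be pictured as a pointwise Newton iteration on the FOC residual. The sketch below is plain Python and independent of self.solve_policy(), whose internals are not described here. The FOC u'(c) = V'(s) with CRRA marginal utility c^(-γ) is a hypothetical example, chosen because its closed form c = V'(s)^(-1/γ) allows a check.

```python
GAMMA = 2.0  # illustrative risk aversion


def foc(c, dv):
    """FOC residual u'(c) - V'(s) for CRRA utility; zero at the optimum."""
    return c ** (-GAMMA) - dv


def foc_derivative(c):
    """Derivative of the FOC residual with respect to the control c."""
    return -GAMMA * c ** (-GAMMA - 1.0)


def newton_policy(dv, c0=0.1, tol=1e-12, max_iter=50):
    """Solve foc(c, dv) = 0 for one grid point with Newton's method.

    The small starting value keeps the iterates below the root, so they
    stay positive; starting above the root can overshoot to c < 0.
    """
    c = c0
    for _ in range(max_iter):
        step = foc(c, dv) / foc_derivative(c)
        c -= step
        if abs(step) < tol:
            break
    return c


dv_grid = [1.0, 2.0, 4.0]                       # V'(s) at three grid points
policy = [newton_policy(dv) for dv in dv_grid]  # one solve per grid point
closed_form = [dv ** (-1.0 / GAMMA) for dv in dv_grid]
print(policy, closed_form)
```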

3. hjb_residual(v, dv, d2v, s, policy, p) -> Array (static method)

Defines the Hamilton-Jacobi-Bellman equation that your value function must satisfy. This method computes the residual of the HJB equation - when the solution is correct, this residual should be zero everywhere.

The HJB equation typically has the form:

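ρV(s) = max_c [u(c,s) + μ(c,s)V'(s) + (1/2)σ²(c,s)V''(s)]

Your implementation should return the right-hand side minus the left-hand side, evaluated at the supplied policy: u + μV' + (1/2)σ²V'' - ρV. This is the arrangement used by the hjb_residual examples in this document.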
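
One way to sanity-check a residual function is to feed it a value function whose closed form is known. The sketch below uses a hypothetical control-free problem, ρV = s + μsV' + (1/2)σ²s²V'', whose solution is V(s) = s/(ρ - μ). The parameter names and the dictionary used for p are assumptions for illustration, and the terms are arranged as in the Quick Start (flow terms plus a negative discount term).

```python
def hjb_residual(v, dv, d2v, s, policy, p):
    """Residual of rho*V = s + mu*s*V' + 0.5*sigma^2*s^2*V'' (no controls)."""
    flow = s
    drift_term = p["mu"] * s * dv
    diffusion_term = 0.5 * p["sigma"] ** 2 * s**2 * d2v
    discount_term = -p["rho"] * v
    return flow + drift_term + diffusion_term + discount_term


p = {"rho": 0.06, "mu": 0.02, "sigma": 0.2}
grid = [0.5, 1.0, 2.0]

# Exact solution V = s / (rho - mu) with V' = 1 / (rho - mu) and V'' = 0
scale = 1.0 / (p["rho"] - p["mu"])
residual_exact = [hjb_residual(s * scale, scale, 0.0, s, {}, p) for s in grid]

# A wrong guess, V(s) = s, leaves a visibly nonzero residual
residual_wrong = [hjb_residual(s, 1.0, 0.0, s, {}, p) for s in grid]
print(residual_exact, residual_wrong)
```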

Advanced Methods (Optional)

4. update_boundary(solution) -> (dict, float)

For problems with endogenous boundaries that emerge naturally from the economic model. These boundaries are determined by economic equilibrium conditions rather than external constraints.

When to use: Problems where the boundary location is itself an economic decision, such as:

  • Optimal liquidation/payout boundaries in corporate finance

Implementation principle: Use economic conditions to determine where boundaries should be located. Common approaches include finding points where marginal conditions are satisfied (e.g., where the marginal value of cash equals its outside value).

5. bisection_boundary_error(solution) -> float

For problems where the boundary location must satisfy a specific mathematical condition that can be expressed as a target function equaling zero.

When to use: Problems with well-defined optimality conditions at boundaries, such as:

  • Super-contact conditions: V''(boundary) = 0
  • Value matching with known target: V(boundary) = target_value
  • Smooth pasting conditions: V'(boundary) = known_slope

Implementation principle: Define a function that measures how far the current solution deviates from the target condition. The bisection algorithm will find the boundary value that makes this function equal zero.

Critical: Return Value Sign Convention

The sign of the returned error determines the bisection search direction:

  • Positive error: The current boundary value is too high, so the algorithm decreases it
  • Negative error: The current boundary value is too low, so the algorithm increases it

Example: For the super-contact condition V''(s_max) = 0 with a concave value function:

  • If V''(s_max) < 0: The current s_max is too small, need to increase it → return negative error
  • If V''(s_max) > 0: The current s_max is too large, need to decrease it → return positive error
  • Therefore: return solution.boundary_derivative.d2v_right (same sign as the condition)
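
The convention can be illustrated with a generic bisection loop. This is a sketch of the technique under the stated sign convention, not the package's bisection_search(); toy_error is a hypothetical stand-in for solving the HJB equation at a trial boundary and reading off V''(s_max).

```python
def toy_error(s_max):
    """Stand-in for bisection_boundary_error: negative below 0.22, positive above."""
    return s_max - 0.22


def bisect_boundary(error_fn, low, high, tol=1e-8, max_iter=100):
    """Shrink [low, high] until the boundary is located to within tol."""
    for _ in range(max_iter):
        mid = 0.5 * (low + high)
        err = error_fn(mid)
        if err > 0.0:
            high = mid   # positive error: boundary too high, search lower
        else:
            low = mid    # negative error: boundary too low, search higher
        if high - low < tol:
            break
    return 0.5 * (low + high)


s_star = bisect_boundary(toy_error, low=0.1, high=0.3)
print(s_star)
```

If the error function had the opposite sign, the loop would move away from the root and end at one of the initial bounds, which is why the sign of the returned value matters.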

Implementation Framework

class Solver(AbstractSolver[Parameter, PolicyDict]):
  
    def initialize_policy(self):
        # Provide initial guess for all policy variables
        return PolicyDict(
            control_var1=jnp.full_like(self.s, initial_guess1),
            control_var2=jnp.full_like(self.s, initial_guess2),
            # ... other controls
        )
  
    def update_policy(self, v, dv, d2v, s, p):
        # Method 1: Direct analytical solution
        control1 = ...  # Solve FOC analytically
  
        # Method 2: Newton iteration for complex FOCs
        def control2_foc(control_val, v, dv, d2v, s, other_policy):
            # Define FOC residual that should equal zero
            return ...  # FOC expression
  
        other_policy = {"control_var1": control1}
        control2, convergence_error = self.solve_policy(
            "control_var2", control2_foc, v, dv, d2v, s, other_policy
        )
  
        return PolicyDict(control_var1=control1, control_var2=control2)
  
    @staticmethod
    def hjb_residual(v, dv, d2v, s, policy, p):
        control1 = policy["control_var1"]
        control2 = policy["control_var2"]
  
        # Build HJB equation components
        drift_term = ...      # μ(c,s)V'(s)
        diffusion_term = ...  # (1/2)σ²(c,s)V''(s)
        utility_term = ...    # u(c,s)
        discount_term = ...   # -ρV(s)
  
        return drift_term + diffusion_term + utility_term + discount_term
  
    # Optional: For endogenous boundary problems
    def update_boundary(self, solution):
        # Find boundary based on economic conditions
        # Example: Find point where marginal value condition is satisfied
        new_boundary_value = ...  # Economic calculation
        error = ...  # Measure of convergence
        return {"boundary_name": new_boundary_value}, error
  
    # Optional: For bisection boundary search
    def bisection_boundary_error(self, solution):
        # Return target function that should equal zero
        # Sign determines search direction: 
        # positive = decrease boundary, negative = increase boundary
        return solution.boundary_derivative.d2v_right

Implementation Workflow

  1. Define your economic model: Create Parameter and PolicyDict classes
  2. Set up boundaries: Create Boundary class (simple or advanced)
  3. Implement solution logic: Create Solver class with the three core methods
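  4. Solve: Instantiate and call solver.solve(), or search_boundary() / bisection_search() for endogenous boundaries
  5. Analyze: Use solver.solution for analysis and visualization (the solve methods themselves return an Info object)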

API Reference

Solver Configuration

The solver constructor provides extensive configuration options to fine-tune performance and convergence behavior for your specific problem:

solver = Solver(
    p=parameters,                  # Your Parameter instance
    boundary=boundary,             # Your Boundary instance
  
    # Grid settings
    number=1000,                   # Number of grid points
    interval=None,                 # Alternative: fixed spacing
  
    # Policy iteration settings
    policy_max_iter=100,          # Maximum policy iterations
    policy_tol=1e-8,              # Convergence tolerance
    policy_patience=15,           # Early stopping patience
  
    # Value function update settings  
    value_max_iter=50,            # Max iterations per value update
    value_tol=1e-8,               # Value function tolerance
    value_patience=10,            # Value update patience
  
    # Newton solver settings (for complex policies)
    newton_max_iter=10,           # Max Newton iterations
    newton_tol=1e-12,             # Newton tolerance
  
    # Boundary search settings
    boundary_max_iter=30,         # Max boundary search iterations
    boundary_tol=1e-5,            # Boundary convergence tolerance
    boundary_patience=5,          # Boundary search patience
  
    # Initial policy
    guess_policy=True             # Use initialize_policy() as starting point
)

Core Solution Methods

1. solve() - Basic HJB Solution

Solves the HJB equation with fixed boundary conditions using policy iteration. This is the core method for most problems where all boundary values are known.

When to use: Problems with exogenous boundaries or when boundary optimization is not needed.

Returns: Info object containing convergence diagnostics and solution metadata.

info = solver.solve()

Info object attributes:

  • converged: bool - Whether algorithm converged
  • iterations: int - Number of policy iterations taken
  • final_error: float - Final convergence error
  • time: float - Solution time in seconds

2. search_boundary() - Endogenous Boundary Optimization

Optimizes boundary conditions by iterating between HJB solution and boundary updates. Use this when boundary values are determined endogenously by economic conditions.

When to use: Problems where boundary conditions depend on the solution itself (e.g., optimal stopping problems, free boundary problems).

Requirements: Must implement update_boundary(solution) -> (dict, float) in your Solver class.

info = solver.search_boundary(
    max_iter=20,                  # Maximum boundary search iterations
    tol=1e-4,                     # Boundary convergence tolerance
    patience=5                    # Early stopping patience
)

Implementation example:

def update_boundary(self, solution):
    # Example: update based on smooth pasting condition
    new_s_max = find_smooth_pasting_point(solution)
    new_v_right = compute_terminal_value(new_s_max, solution)
  
    boundary_updates = {"s_max": new_s_max, "v_right": new_v_right}
    error = abs(new_s_max - self.boundary.s_max)  # convergence measure
  
    return boundary_updates, error
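
The iteration that drives such a method can be sketched as a fixed-point loop: propose a new boundary, record the error, and stop once the error falls below the tolerance. The loop below is a plain-Python illustration under that assumption, not search_boundary() itself; toy_update is a hypothetical contraction standing in for update_boundary().

```python
def toy_update(s_max):
    """Stand-in for update_boundary: returns (updates, error) like the real hook."""
    new_s_max = 0.5 * s_max + 0.11          # contraction with fixed point 0.22
    return {"s_max": new_s_max}, abs(new_s_max - s_max)


def search(update_fn, s_max, tol=1e-4, max_iter=20):
    """Apply boundary updates until the reported error drops below tol."""
    history = []
    for iteration in range(1, max_iter + 1):
        updates, error = update_fn(s_max)
        s_max = updates["s_max"]
        history.append(error)
        if error < tol:
            return s_max, iteration, True, history
    return s_max, max_iter, False, history


s_max, iterations, converged, history = search(toy_update, s_max=0.3)
print(s_max, iterations, converged)
```

Unlike bisection, a loop of this kind converges only if the update rule pulls successive guesses toward the fixed point, which is one reason a well-defined target condition is often easier to handle with bisection_search().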

3. bisection_search() - Targeted Boundary Search

Uses the bisection method to find the boundary value that satisfies a specific target condition. This is more robust than search_boundary() when you have a well-defined target function.

When to use: When you have a specific condition that the optimal boundary must satisfy (e.g., the super-contact condition V''(s_max) = 0).

Requirements: Must implement bisection_boundary_error(solution) -> float in your Solver class.

info = solver.bisection_search(
    boundary_name="s_max",        # Which boundary parameter to search over
    low=0.1,                      # Lower bound for search
    high=2.0,                     # Upper bound for search  
    tol=1e-4,                     # Bisection tolerance
    max_iter=20,                  # Maximum bisection iterations
    patience=5                    # Patience for early stopping
)

Implementation example:

def bisection_boundary_error(self, solution):
    # Example: super-contact condition V''(s_max) = 0
    # For concave value function, return the same sign as the condition
    # - If d2v_right < 0: s_max too small, need increase → negative error
    # - If d2v_right > 0: s_max too large, need decrease → positive error
    return solution.boundary_derivative.d2v_right

Working with Solutions

All solver methods return an Info object with convergence diagnostics and automatically update solver.solution with the complete results. The solution object provides multiple ways to access and analyze your results.

Solution Components

# Solve and get convergence info
info = solver.solve()
solution = solver.solution

# Basic solution arrays
state_grid = solution.s           # State variable values
value_function = solution.v       # V(s)
marginal_value = solution.dv      # V'(s)  
curvature = solution.d2v         # V''(s)

# Policy functions (keys match the fields of your PolicyDict)
optimal_policies = solution.policy
investment = solution.policy["investment"]
consumption = solution.policy["consumption"]  # only if your PolicyDict defines it

# Grid and boundary information
grid_points = solution.number     # Number of grid points
boundary_values = solution.boundary.get_boundary()  # (s_min, s_max, v_left, v_right)

Analysis and Interpolation

The solution provides tools for detailed analysis and evaluation at arbitrary points:

# DataFrame interface for comprehensive analysis
df = solution.df
pp(df)
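
For evaluation between grid points, one option is linear interpolation on the solution arrays. The sketch below uses only the standard library and hypothetical stand-in lists for solution.s and solution.v; with real solution arrays, numpy.interp or jax.numpy.interp performs the same piecewise-linear lookup.

```python
from bisect import bisect_right


def interpolate(x, grid, values):
    """Piecewise-linear interpolation of values on an increasing grid."""
    if not grid[0] <= x <= grid[-1]:
        raise ValueError("x lies outside the grid")
    # Index of the right end of the interval containing x
    hi = min(bisect_right(grid, x), len(grid) - 1)
    lo = hi - 1
    weight = (x - grid[lo]) / (grid[hi] - grid[lo])
    return (1.0 - weight) * values[lo] + weight * values[hi]


# Stand-ins for solution.s and solution.v (illustrative numbers only)
s_grid = [0.0, 0.1, 0.2, 0.3]
v_grid = [0.9, 1.1, 1.25, 1.35]

v_mid = interpolate(0.15, s_grid, v_grid)
v_end = interpolate(0.3, s_grid, v_grid)
print(v_mid, v_end)
```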

Solution Export

The solution object provides convenient export methods for sharing results and further analysis in other tools:

# Export in multiple formats
solution.save("results.feather")          # Fast binary format (recommended)
solution.save("results.csv")              # CSV for interoperability  
solution.save("report.xlsx")              # Excel for sharing

Examples

The src/examples/ folder contains complete implementations demonstrating various features:

Available Examples

1. BCW2011Liquidation.py

From: Bolton, Chen, and Wang (2011) - "A Unified Theory of Tobin's q, Corporate Investment, Financing, and Risk Management", Case I: Liquidation

Features:

  • Simple single-control problem (investment only)
  • Endogenous boundary optimization using bisection search
  • Dependent boundary conditions with complex calculations
  • Smooth pasting conditions

Use Case: Perfect for learning the framework basics and understanding endogenous boundary problems.

2. BCW2011Hedging.py

From: Bolton, Chen, and Wang (2011) - Same paper, full model with hedging

Features:

  • Multi-control problem (investment + hedging)
  • Complex constraint handling (margin requirements)
  • Multiple boundary regions with switching
  • Advanced policy updates with interior/boundary solutions

Use Case: Advanced example showing multiple controls and constraint handling.

Running Examples

# Run any example directly (shell)
python src/examples/BCW2011Liquidation.py

# Or import and modify (Python)
from src.examples.BCW2011Liquidation import Parameter, Solver, Boundary

# Modify parameters and re-solve
# (assumes the example's Parameter has the same fields as in the Quick Start)
custom_params = Parameter(sigma=0.12)  # e.g. higher volatility than the default 0.09
boundary = Boundary(p=custom_params, s_min=0.0, s_max=0.25)
solver = Solver(p=custom_params, boundary=boundary)

info = solver.bisection_search(boundary_name="s_max", low=0.1, high=0.4)

License

This project is licensed under the MIT License.

Citation

If you use this package in your research, please cite:

@software{hjb_solver,
  title={HJB Solver: A Python Package for Solving Hamilton-Jacobi-Bellman Equations},
  author={Haotian Deng},
  year={2025},
  url={https://github.com/Su-luoya/hjb-solver}
}

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Support

For questions and support:

  • Open an issue on GitHub
  • Check the examples in src/examples/
  • Review the API documentation

Happy solving! 🚀
