
JAX-based transition probability computation for multi-state models


jact

JAX-based transition probability and expected cashflow computation for multi-state models with duration-dependent transition intensities.

What is jact?

jact computes transition probabilities and expected cashflows in semi-Markov multi-state models. It takes fitted intensity models (parametric functions, GLMs, neural networks, or any other JIT-compatible callable) and produces transition probabilities and cashflow streams for thousands of individuals in a single vectorized pass. Computations are optimized for JIT-compiled execution on GPU.
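As an illustration of what a "JIT-compatible callable" with duration dependence can look like, here is a sketch of a GLM-style log-linear intensity in plain JAX, evaluated for a batch of individuals over a grid of durations. The function name, its `(age, d)` signature and the coefficients are hypothetical; the signature jact actually expects for intensity functions is defined in the API specification.

```python
import jax
import jax.numpy as jnp


@jax.jit
def disabled_mortality(age, d):
    """Hypothetical log-linear intensity in attained age and duration in state.

    age: shape (n,) attained ages; d: shape (m,) durations since entering the state.
    Returns intensities of shape (n, m) via broadcasting.
    """
    # Illustrative coefficients: excess mortality decays with time spent in the state.
    log_mu = -9.0 + 0.09 * age[:, None] + 1.5 * jnp.exp(-2.0 * d[None, :])
    return jnp.exp(log_mu)


ages = jnp.linspace(30, 80, 1_000)
durations = jnp.linspace(0.0, 10.0, 121)
mu = disabled_mortality(ages, durations)  # shape (1000, 121)
```

Because the function is pure and array-shaped, the same code runs on CPU or GPU and can be swapped for a fitted GLM or neural network with the same inputs.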

Quick example

import jax.numpy as jnp
import jact

# Define the state space
state_space = jact.StateSpace(
    states=["healthy", "disabled", "dead"],
    transitions=[
        ("healthy", "disabled"),
        ("healthy", "dead"),
        ("disabled", "dead"),
    ],
)

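# Build a model with intensity functions
# (onset_fn, mortality_fn and disabled_mort_fn are fitted, JIT-compatible
# intensity callables defined elsewhere)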
model = state_space.build(
    transitions={
        ("healthy", "disabled"): onset_fn,
        ("healthy", "dead"): mortality_fn,
        ("disabled", "dead"): disabled_mort_fn,
    }
)

# Transition probabilities for 1,000 individuals aged 30 to 80,
# over 30 time units with 12 solver steps per unit
ages = jnp.linspace(30, 80, 1_000)
result = model.solve(initial="healthy", horizon=30, steps_per_unit=12, age=ages)
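For intuition about the kind of computation `solve` performs, the sketch below propagates state-occupancy probabilities for the same three-state model with explicit Euler steps of the Kolmogorov forward equations, vectorized over 1,000 ages with `jax.lax.scan`. It is a deliberately simplified Markov version, not jact's algorithm: the hypothetical Gompertz-type intensities (illustrative parameters) depend only on attained age, not on duration in state, and only the final probabilities are returned.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical Gompertz-type intensities (per year), depending on attained age only.
def onset(age):
    return 0.0005 * jnp.exp(0.06 * age)


def mort_healthy(age):
    return 0.0001 * jnp.exp(0.09 * age)


def mort_disabled(age):
    return 0.0003 * jnp.exp(0.09 * age)


@partial(jax.jit, static_argnames=("horizon", "steps_per_unit"))
def occupancy(ages, horizon=30, steps_per_unit=12):
    """Probabilities of (healthy, disabled, dead) after `horizon` years."""
    dt = 1.0 / steps_per_unit
    # Everyone starts in "healthy".
    p0 = jnp.zeros((ages.shape[0], 3)).at[:, 0].set(1.0)

    def step(p, k):
        age = ages + k * dt  # attained age at the start of the step
        h_to_d = p[:, 0] * onset(age) * dt
        h_to_x = p[:, 0] * mort_healthy(age) * dt
        d_to_x = p[:, 1] * mort_disabled(age) * dt
        # Each outflow from one state is an inflow to another, so rows keep summing to 1.
        flow = jnp.stack([-h_to_d - h_to_x, h_to_d - d_to_x, h_to_x + d_to_x], axis=1)
        return p + flow, None

    p_final, _ = jax.lax.scan(step, p0, jnp.arange(horizon * steps_per_unit))
    return p_final


ages = jnp.linspace(30, 80, 1_000)
probs = occupancy(ages)  # shape (1000, 3)
```

A semi-Markov solver additionally has to track the distribution over duration in each state, which is what the duration argument `d` in the callbacks refers to.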

Cashflow example

Using the same state_space, model, and ages as above:

import jax.numpy as jnp
import jact


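# Premium paid at a rate of 1,200 per unit time while healthy
# (premiums negative, benefits positive); one value per individual and duration point.
def annual_premium(t, d, *, age):
    return jnp.full((age.shape[0], d.shape[-1]), -1_200.0)


# Lump sum of 100,000 paid on a transition to "dead".
def death_benefit(t, d, *, age):
    return jnp.full((age.shape[0], d.shape[-1]), 100_000.0)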


cashflows = state_space.cashflows(
    {
        "premium": jact.StateRate({"healthy": annual_premium}),
        "death_benefit": jact.TransitionLump(
            {
                ("healthy", "dead"): death_benefit,
                ("disabled", "dead"): death_benefit,
            }
        ),
    }
)

result = model.solve(
    initial="healthy",
    horizon=30,
    steps_per_unit=12,
    record_every=12,
    probability=None,
    cashflows=cashflows,
    cashflow_views={
        "raw": jact.Raw(),
        "pv_total": jact.Total(
            # Continuous discounting at a 3% force of interest.
            weight=lambda t, **kwargs: jnp.exp(-0.03 * t),
            terminal=True,
        ),
    },
    age=ages,
)

premium_stream = result["cashflows"]["raw"]["premium"]
present_value = result["cashflows"]["pv_total"]
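The arithmetic behind a discounted total can be sketched for a single individual on a monthly grid: a state rate contributes rate * P(in state) * v(t) * dt, and a transition lump contributes amount * P(in origin state) * intensity * v(t) * dt, with v(t) = exp(-0.03 t). The left-endpoint Riemann sum, the constant 1% mortality and the helper names `pv_state_rate` and `pv_transition_lump` are assumptions for illustration, not jact's quadrature or API.

```python
import jax.numpy as jnp

steps_per_unit, horizon, delta = 12, 30, 0.03
dt = 1.0 / steps_per_unit
t = jnp.arange(horizon * steps_per_unit) * dt  # left end of each monthly step
v = jnp.exp(-delta * t)  # discount weight exp(-0.03 t)


def pv_state_rate(rate, p_state):
    # Rate paid while in a state: sum of rate * P(in state at t) * v(t) * dt.
    return jnp.sum(rate * p_state * v * dt)


def pv_transition_lump(amount, p_from, mu):
    # Lump paid on a transition: sum of amount * P(in origin) * intensity * v(t) * dt.
    return jnp.sum(amount * p_from * mu * v * dt)


# Hypothetical inputs: constant mortality of 0.01 per year and no disability.
mu = 0.01
p_healthy = jnp.exp(-mu * t)
pv_premium = pv_state_rate(-1_200.0, p_healthy)
pv_benefit = pv_transition_lump(100_000.0, p_healthy, mu)
pv_total = pv_premium + pv_benefit
```

With constant intensity the integrals have closed forms, -1,200 * (1 - e^-1.2) / 0.04 for the premiums and 100,000 * 0.25 * (1 - e^-1.2) for the benefit, which the sums approximate closely at monthly resolution.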

Key features

  • Plug in any model: Gompertz, GLM, neural network, or any other JIT-compatible callable.
  • Swap and compare: Reuse the same StateSpace with different intensity models to compare them.
  • Probabilities and cashflows together: Emit both in one fused solve, with solve-time cashflow views for grouping and valuation.
  • Compute only what's needed: The solver reduces to states reachable from the initial state.
  • Exact seeded starts: Initial point masses preserve per-individual starting duration d_0 exactly.
  • Batch-first: Designed for 100K+ individuals in a single pass.
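The reduction to reachable states can be illustrated with a breadth-first search over the declared transitions. This is a sketch of the idea only; `reachable` is a hypothetical helper, not part of jact.

```python
from collections import deque


def reachable(initial, transitions):
    """States reachable from `initial` along directed (origin, target) transitions."""
    successors = {}
    for origin, target in transitions:
        successors.setdefault(origin, []).append(target)
    seen = {initial}
    queue = deque([initial])
    while queue:
        state = queue.popleft()
        for nxt in successors.get(state, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


transitions = [("healthy", "disabled"), ("healthy", "dead"), ("disabled", "dead")]
from_disabled = reachable("disabled", transitions)  # only "disabled" and "dead"
```

Starting in "disabled", the "healthy" state can never be occupied, so a solver can drop it and its outgoing transitions.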

Documentation

See the documentation index for the public documentation set. For the full API contract, use the API specification. For a runnable walkthrough of the main workflow, see the example notebook. For a fitting-to-solver workflow with neural-network intensities, see the fitted neural-network notebook.

Namespace

The recommended user API is the top-level jact namespace: jact.StateSpace, jact.InitialDistribution, jact.solve, and the cashflow declarations such as jact.StateRate and jact.Total. Advanced callback state objects remain available from submodules, for example jact.callbacks.PointMass and jact.model.ReducedModel.

Installation

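pip install jax jaxlib
pip install jact

For GPU execution, install a CUDA-enabled JAX build instead (for example pip install -U "jax[cuda12]"), as described in the JAX installation guide.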

For local development from this repository:

pip install -e '.[dev]'
pytest

The package uses a src/ layout, so editable install is the intended local workflow.

To run the example notebook with plotting support from a local checkout:

pip install -e '.[dev,notebook]'

Release checks

Before cutting a PyPI release:

rm -rf build dist src/*.egg-info
python -m build --no-isolation
python -m twine check dist/*
pytest -q

The tag-driven publish flow is documented in RELEASING.md.

Requirements

  • Python >= 3.10
  • JAX >= 0.4
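A quick way to confirm these requirements, and to see whether JAX picked up a GPU backend, is sketched below. `check_environment` is a hypothetical helper, not part of jact; it only uses `sys.version_info`, `jax.__version__` and `jax.default_backend()`.

```python
import sys

import jax


def check_environment():
    """Return (python_ok, jax_ok, backend) for the stated requirements."""
    python_ok = sys.version_info >= (3, 10)
    major, minor = (int(part) for part in jax.__version__.split(".")[:2])
    jax_ok = (major, minor) >= (0, 4)
    return python_ok, jax_ok, jax.default_backend()


python_ok, jax_ok, backend = check_environment()
print(f"Python >= 3.10: {python_ok}, JAX >= 0.4: {jax_ok}, backend: {backend}")
```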

License

Apache-2.0



Download files

Download the file for your platform.

Source Distribution

jact-0.1.3.tar.gz (30.5 kB)

Uploaded Source

Built Distribution


jact-0.1.3-py3-none-any.whl (30.8 kB)

Uploaded Python 3

File details

Details for the file jact-0.1.3.tar.gz.

File metadata

  • Download URL: jact-0.1.3.tar.gz
  • Upload date:
  • Size: 30.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.3.tar.gz
Algorithm Hash digest
SHA256 a6930ab194c55bbaac5a649f8f2badde59a42dc47ff7f6e5f3548bb9500e9e39
MD5 460b2576cc9e38bcbe4c96074824dd3d
BLAKE2b-256 cd9be40645f0bc2a5254d5465cdd53865e69a43009081f2d75946e416ad3b3ab
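To check a downloaded file against the SHA256 digests listed here, it can be hashed locally with Python's standard `hashlib` module, as sketched below; `sha256_of` is a hypothetical helper name. pip can also enforce digests itself through its hash-checking mode (`--require-hashes`).

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED_SDIST_SHA256 = "a6930ab194c55bbaac5a649f8f2badde59a42dc47ff7f6e5f3548bb9500e9e39"
# After downloading the sdist into the current directory:
# assert sha256_of("jact-0.1.3.tar.gz") == EXPECTED_SDIST_SHA256
```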


Provenance

The following attestation bundles were made for jact-0.1.3.tar.gz:

Publisher: publish.yml on stojl/jact

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jact-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: jact-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 30.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 c1cf779ab6511ef6245c7459c6b18f1fa5ecfe9df098a74967005757d31a7473
MD5 96607f3df51e931804140b806d795d91
BLAKE2b-256 cce00f9c648cc84e818cb200783057062d471c2f4a9c4c20698fb57425d23e8a


Provenance

The following attestation bundles were made for jact-0.1.3-py3-none-any.whl:

Publisher: publish.yml on stojl/jact

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
