
JAX-based transition probability computation for multi-state models


jact

JAX-based transition probability and expected cashflow computation for multi-state models with duration-dependent transition intensities.

What is jact?

jact computes transition probabilities and expected cashflows in semi-Markov multi-state models. It takes fitted intensity models (parametric functions, GLMs, neural networks, or any other JIT-compatible callable) and produces transition probabilities and cashflow streams for thousands of individuals in a single vectorized, JIT-compiled pass, designed for efficient execution on GPU.
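The following is a simplified sketch in plain JAX, not jact's implementation, of the kind of computation involved. It ignores duration dependence, so it models a Markov rather than a semi-Markov process. It integrates the Kolmogorov forward equations for a healthy/disabled/dead model with an explicit Euler scheme and vectorizes over individuals with `jax.vmap`. The Gompertz-style parameters and the helper names `intensities` and `state_probs` are hypothetical, chosen for illustration only.

```python
import jax
import jax.numpy as jnp


# Hypothetical Gompertz-style intensities (per year) depending only on attained
# age. A semi-Markov model would also let them depend on the duration d.
def intensities(x):
    onset = 0.0005 * jnp.exp(0.07 * x)   # healthy -> disabled
    mort = 0.00005 * jnp.exp(0.09 * x)   # healthy -> dead
    disabled_mort = 2.0 * mort           # disabled -> dead
    return onset, mort, disabled_mort


def state_probs(age, horizon=30.0, steps_per_unit=12):
    """Explicit Euler scheme for the forward equations, starting in 'healthy'."""
    h = 1.0 / steps_per_unit
    n_steps = int(horizon * steps_per_unit)

    def step(p, k):
        t = k * h
        onset, mort, dmort = intensities(age + t)
        healthy, disabled, dead = p[0], p[1], p[2]
        # Probability mass moved along each transition during one step
        flow_hd = healthy * onset * h
        flow_hx = healthy * mort * h
        flow_dx = disabled * dmort * h
        new_p = jnp.stack([
            healthy - flow_hd - flow_hx,
            disabled + flow_hd - flow_dx,
            dead + flow_hx + flow_dx,
        ])
        return new_p, new_p

    p0 = jnp.array([1.0, 0.0, 0.0])
    _, path = jax.lax.scan(step, p0, jnp.arange(n_steps))
    return path  # shape (n_steps, 3): probabilities of each state over time


# One vectorized, compiled call covering 1,000 individuals
ages = jnp.linspace(30.0, 80.0, 1_000)
paths = jax.jit(jax.vmap(state_probs))(ages)
print(paths.shape)  # (1000, 360, 3)
```

Because every step only moves mass between states, the probabilities for each individual sum to one at every time point. jact additionally tracks duration in the current state, which this sketch omits.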

Quick example

import jax.numpy as jnp
import jact

# Define the state space
state_space = jact.StateSpace(
    states=["healthy", "disabled", "dead"],
    transitions=[
        ("healthy", "disabled"),
        ("healthy", "dead"),
        ("disabled", "dead"),
    ],
)

# Build a model from fitted intensity functions (onset_fn, mortality_fn and
# disabled_mort_fn are user-supplied, JIT-compatible callables)
model = state_space.build(
    transitions={
        ("healthy", "disabled"): onset_fn,
        ("healthy", "dead"): mortality_fn,
        ("disabled", "dead"): disabled_mort_fn,
    }
)

# Compute transition probabilities for 1,000 individuals aged 30 to 80,
# over a 30-year horizon with 12 steps per year
ages = jnp.linspace(30, 80, 1_000)
result = model.solve(initial="healthy", horizon=30, steps_per_unit=12, age=ages)
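The quick example leaves onset_fn, mortality_fn and disabled_mort_fn undefined. A minimal sketch of what they could look like follows. It assumes that intensity callables follow the same calling convention as the cashflow callables in the next example: time t, a duration grid d, and covariates as keyword arguments, returning an array of shape (individuals, duration points). This is an assumption, not a documented contract; check the API specification for the actual signature. The Gompertz parameters and the `gompertz` helper are hypothetical.

```python
import jax.numpy as jnp


# ASSUMED signature, mirroring the cashflow callables: (t, d, *, age) -> (n, m)
def gompertz(a, b):
    """Hypothetical factory for a Gompertz intensity a * exp(b * attained_age)."""
    def intensity(t, d, *, age):
        x = age[:, None] + t  # attained age, shape (n, 1)
        return jnp.broadcast_to(a * jnp.exp(b * x), (age.shape[0], d.shape[-1]))
    return intensity


onset_fn = gompertz(5e-4, 0.07)
mortality_fn = gompertz(5e-5, 0.09)


def disabled_mort_fn(t, d, *, age):
    # Hypothetical duration effect: excess mortality that fades with time since onset
    base = mortality_fn(t, d, age=age)
    return base * (1.0 + 2.0 * jnp.exp(-d))
```

The duration argument d is what distinguishes a semi-Markov intensity: here disabled mortality is highest just after onset and decays toward the baseline.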

Cashflow example

Using the same state_space, model, and ages as above:

import jax.numpy as jnp
import jact


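# Cashflow callables receive time t, a duration grid d and covariates as
# keyword arguments, and return one amount per individual and duration point.
def annual_premium(t, d, *, age):
    # Premium rate of 1,200 per year, entered with a negative sign
    return jnp.full((age.shape[0], d.shape[-1]), -1_200.0)


def death_benefit(t, d, *, age):
    # Lump sum of 100,000 paid on the transition to "dead"
    return jnp.full((age.shape[0], d.shape[-1]), 100_000.0)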


cashflows = state_space.cashflows(
    {
        "premium": jact.cashflows.StateRate({"healthy": annual_premium}),
        "death_benefit": jact.cashflows.TransitionLump(
            {
                ("healthy", "dead"): death_benefit,
                ("disabled", "dead"): death_benefit,
            }
        ),
    }
)

result = model.solve(
    initial="healthy",
    horizon=30,
    steps_per_unit=12,
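    record_every=12,       # record once every 12 steps, i.e. yearly
    probability=None,      # no probability output; cashflows only
    cashflows=cashflows,
    cashflow_views={
        "raw": jact.cashflows.Raw(),
        # Present value: continuous discounting at 3% per year
        "pv_total": jact.cashflows.Total(
            weight=lambda t, **kwargs: jnp.exp(-0.03 * t),
            terminal=True,
        ),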
    },
    age=ages,
)

premium_stream = result.cashflows["raw"]["premium"]
present_value = result.cashflows["pv_total"]
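As a rough illustration of what a discount weight does, the sketch below applies the same weight, exp(-0.03 t), to a hypothetical yearly cashflow stream in plain JAX and sums over time. It assumes recorded amounts fall at t = 1, ..., 30 and that the view multiplies each amount by the weight before summing. jact's exact treatment of `terminal=True` and of timing within each recording interval is not reproduced here, and `expected_cf` is made-up data.

```python
import jax.numpy as jnp

# Hypothetical expected cashflows for 3 individuals, one value per year
times = jnp.arange(1, 31, dtype=jnp.float32)        # t = 1, 2, ..., 30
expected_cf = jnp.full((3, times.shape[0]), -1_200.0)

# Same weight as the pv_total view: continuous discounting at 3% per year
weight = jnp.exp(-0.03 * times)

# Present value per individual: weighted sum over the time axis
present_value = (expected_cf * weight).sum(axis=-1)
```

For a constant stream this matches the closed-form annuity factor e^(-0.03) (1 - e^(-0.9)) / (1 - e^(-0.03)), which is a convenient sanity check when experimenting with custom weights.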

Key features

  • Plug in any model: Gompertz, GLM, neural network, or anything else that is JIT-compatible.
  • Swap and compare: Reuse the same StateSpace with different intensity models and compare the results directly.
  • Probabilities and cashflows together: Emit both in one fused solve, with solve-time cashflow views for grouping and valuation.
  • Compute only what's needed: The solver reduces to states reachable from the initial state.
  • Exact seeded starts: Initial point masses preserve per-individual starting duration d_0 exactly.
  • Batch-first: Designed for 100K+ individuals in a single pass.
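Restricting the solve to states reachable from the initial state is a graph-reachability problem on the transition list. A generic breadth-first-search sketch of the idea in plain Python follows. It is not jact's internal code, and `reachable_states` is a hypothetical helper.

```python
from collections import deque


def reachable_states(transitions, initial):
    """Breadth-first search over the directed graph of allowed transitions."""
    successors = {}
    for src, dst in transitions:
        successors.setdefault(src, []).append(dst)
    seen = {initial}
    queue = deque([initial])
    while queue:
        state = queue.popleft()
        for nxt in successors.get(state, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


transitions = [
    ("healthy", "disabled"),
    ("healthy", "dead"),
    ("disabled", "dead"),
]
print(sorted(reachable_states(transitions, "disabled")))  # ['dead', 'disabled']
```

Starting from "disabled", the "healthy" state and its outgoing intensities can be dropped entirely, which shrinks the system that has to be solved.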

Documentation

See the documentation index for the public documentation set. For the full API contract, use the API specification. For a runnable walkthrough of the main workflow, see the example notebook. For a fitting-to-solver workflow with neural-network intensities, see the fitted neural-network notebook.

Namespace

The top-level jact namespace exposes the core types: jact.StateSpace, jact.Model, jact.InitialDistribution, jact.ModelResult, and jact.solve. Domain types live under two submodules:

  • jact.cashflows for declarations and views (StateRate, TransitionLump, ScheduledEvent, Raw, Group, Total, ByState, ByKind).
  • jact.probability for output reducers (StateProbability, DensityProbability, Density, PointMass, MarginalComponents, Full).

Advanced inspection types are kept out of the public namespace and live in private modules, for example jact.probability.StateCarry and jact.model.ReducedModel.

Installation

pip install jax jaxlib
pip install jact

For local development from this repository:

pip install -e '.[dev]'
pytest

The package uses a src/ layout, so editable install is the intended local workflow.

To run the example notebook with plotting support from a local checkout:

pip install -e '.[dev,notebook]'

Release checks

Before cutting a PyPI release:

rm -rf build dist src/*.egg-info
python -m build --no-isolation
python -m twine check dist/*
pytest -q

The tag-driven publish flow is documented in RELEASING.md.

Requirements

  • Python >= 3.10
  • JAX >= 0.4

License

Apache-2.0


Download files


Source Distribution

jact-0.1.4.tar.gz (32.2 kB)

Uploaded Source

Built Distribution


jact-0.1.4-py3-none-any.whl (32.9 kB)

Uploaded Python 3

File details

Details for the file jact-0.1.4.tar.gz.

File metadata

  • Download URL: jact-0.1.4.tar.gz
  • Size: 32.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.4.tar.gz
Algorithm     Hash digest
SHA256        34f926bf7db137f20594805c3bc2f159b29360dece55b3da4e37c18b55d05faa
MD5           90a22cc0cebad71e18beb06b5cea03b5
BLAKE2b-256   6254953379ba91458729dbee0059f6990edcf028333265132f04ff7de3c81d49


Provenance

The following attestation bundles were made for jact-0.1.4.tar.gz:

Publisher: publish.yml on stojl/jact


File details

Details for the file jact-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: jact-0.1.4-py3-none-any.whl
  • Size: 32.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.4-py3-none-any.whl
Algorithm     Hash digest
SHA256        e1c879f45033eba6975ec00dfea07f5d65eeef5e3e86efcccea5b4d86da445ec
MD5           fae29ef6fa52f6c6349281a6770e35ab
BLAKE2b-256   f432c4b8a11eedeebea97d9a38a3f268439d9b6f3f5f95d36ec62e260709c07d


Provenance

The following attestation bundles were made for jact-0.1.4-py3-none-any.whl:

Publisher: publish.yml on stojl/jact

