
JAX-based transition probability computation for multi-state models


jact

JAX-based transition probability computation for multi-state models with duration-dependent transition intensities.

What is jact?

jact computes transition probabilities in semi-Markov multi-state models. It takes fitted intensity models (parametric functions, GLMs, neural networks, or any other JIT-compatible callable) and produces transition probabilities for thousands of individuals in a single vectorized pass, with computations optimized for JIT-compiled execution on GPU.
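For illustration, here is what two such intensity models might look like as JIT-compatible JAX callables: a Gompertz-Makeham mortality curve and a GLM-style log-linear onset rate. The function names and parameter values are hypothetical, and the exact call signature jact expects from an intensity callable (for example, which arguments such as age or duration it passes in) is defined in the API specification and is not assumed here.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters chosen for illustration only; they are not fitted values.
GOMPERTZ_A, GOMPERTZ_B, MAKEHAM_C = 5e-5, 0.09, 5e-4


@jax.jit
def mortality_gompertz(age):
    """Gompertz-Makeham intensity mu(x) = c + a * exp(b * x)."""
    return MAKEHAM_C + GOMPERTZ_A * jnp.exp(GOMPERTZ_B * age)


@jax.jit
def onset_loglinear(age, coef):
    """GLM-style log-linear intensity mu(x) = exp(coef[0] + coef[1] * x)."""
    return jnp.exp(coef[0] + coef[1] * age)


ages = jnp.linspace(30.0, 80.0, 1_000)
mu_death = mortality_gompertz(ages)                        # one compiled call, shape (1000,)
mu_onset = onset_loglinear(ages, jnp.array([-9.0, 0.08]))  # shape (1000,)
```

Both functions are pure and built only from jax.numpy operations, so a single compiled call evaluates every individual at once, which is what makes callables like these suitable for a vectorized, JIT-compiled pass.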

Quick example

import jax.numpy as jnp
import jact

# Define the state space
state_space = jact.StateSpace(
    states=["healthy", "disabled", "dead"],
    transitions=[
        ("healthy", "disabled"),
        ("healthy", "dead"),
        ("disabled", "dead"),
    ],
)

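# Build a model with intensity functions. onset_fn, mortality_fn and disabled_mort_fn
# stand for your own fitted, JIT-compatible intensity callables; they are not defined here.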
model = state_space.build(
    transitions={
        ("healthy", "disabled"): onset_fn,
        ("healthy", "dead"): mortality_fn,
        ("disabled", "dead"): disabled_mort_fn,
    }
)

# Compute transition probabilities for 1000 individuals
ages = jnp.linspace(30, 80, 1_000)
result = model.solve(initial="healthy", horizon=30, steps_per_unit=12, age=ages)
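To give a feel for what a call like model.solve(initial="healthy", horizon=30, steps_per_unit=12, ...) has to compute, here is a pure-Python sketch for a single individual. It is not jact's algorithm: it uses a swapped-in explicit Euler scheme, hypothetical intensity functions, and a hypothetical helper named illness_death_probs. Because the disabled-to-dead intensity may depend on the time already spent disabled (the semi-Markov part), disabled probability mass is kept in one bucket per elapsed step rather than as a single number.

```python
import math


def illness_death_probs(age, horizon, steps_per_unit, mu_hi, mu_hd, mu_id):
    """Explicit-Euler sketch of a healthy -> disabled -> dead semi-Markov model.

    mu_hi(x) and mu_hd(x) depend on attained age x; mu_id(x, d) may also depend
    on the duration d already spent in 'disabled'. Starts in 'healthy' at `age`.
    """
    h = 1.0 / steps_per_unit
    p_healthy = 1.0
    p_disabled = []  # p_disabled[j]: mass that has been disabled for about j * h
    for k in range(horizon * steps_per_unit):
        x = age + k * h
        # Move each disabled bucket one step older, removing deaths at duration j * h.
        aged = [q * (1.0 - h * mu_id(x, j * h)) for j, q in enumerate(p_disabled)]
        # New disabilities during this step enter the duration-0 bucket.
        onset = p_healthy * h * mu_hi(x)
        p_healthy *= 1.0 - h * (mu_hi(x) + mu_hd(x))
        p_disabled = [onset] + aged
    p_dis = sum(p_disabled)
    return {"healthy": p_healthy, "disabled": p_dis, "dead": 1.0 - p_healthy - p_dis}


# One individual aged 50, 30-year horizon, monthly steps.
probs = illness_death_probs(
    age=50.0,
    horizon=30,
    steps_per_unit=12,
    mu_hi=lambda x: math.exp(-9.0 + 0.08 * x),
    mu_hd=lambda x: 5e-4 + 5e-5 * math.exp(0.09 * x),
    mu_id=lambda x, d: 0.02 + 0.1 * math.exp(-2.0 * d),  # higher mortality soon after onset
)
```

Explicit Euler is only first-order accurate and needs h * mu well below 1 to keep probabilities sensible. jact evaluates many individuals in one vectorized pass, whereas this sketch handles one age at a time for readability.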

Key features

  • Plug in any model: Gompertz, GLM, neural network, or anything else that is JIT-compatible.
  • Swap and compare: Keep the same StateSpace and swap in different intensity models to compare them.
  • Compute only what's needed: The solver restricts the computation to states reachable from the initial state.
  • Exact seeded starts: Initial point masses preserve each individual's starting duration d_0 exactly.
  • Batch-first: Designed for 100K+ individuals in a single pass.
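The reachability reduction can be pictured as a graph search over the declared transitions. A minimal sketch, assuming a plain breadth-first search (the helper name reachable_states is hypothetical and not part of jact):

```python
from collections import deque


def reachable_states(transitions, initial):
    """Breadth-first search over (source, target) transition pairs."""
    successors = {}
    for source, target in transitions:
        successors.setdefault(source, []).append(target)
    reached = {initial}
    queue = deque([initial])
    while queue:
        state = queue.popleft()
        for nxt in successors.get(state, []):
            if nxt not in reached:
                reached.add(nxt)
                queue.append(nxt)
    return reached


transitions = [("healthy", "disabled"), ("healthy", "dead"), ("disabled", "dead")]
# Starting in "disabled", "healthy" can never be reached, so it can be dropped.
reduced = reachable_states(transitions, "disabled")  # the set {"disabled", "dead"}
```

Dropping unreachable states shrinks the system the solver has to integrate, which matters most for large state spaces where an individual starts deep inside the transition graph.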

Documentation

See the documentation index for the public documentation set. For the full API contract, use the API specification. For a runnable walkthrough of the main workflow, see the example notebook.

Namespace

The recommended user API is the top-level jact namespace: jact.StateSpace, jact.InitialDistribution, jact.solve, and the cashflow declarations such as jact.StateRate and jact.Total. Advanced callback state objects remain available from submodules, for example jact.callbacks.PointMass and jact.model.ReducedModel.

Installation

pip install jax jaxlib
pip install jact

For local development from this repository:

pip install -e '.[dev]'
pytest

The package uses a src/ layout, so editable install is the intended local workflow.

To run the example notebook with plotting support from a local checkout:

pip install -e '.[dev,notebook]'

Release checks

Before cutting a PyPI release:

rm -rf build dist src/*.egg-info
python -m build --no-isolation
python -m twine check dist/*
pytest -q

The tag-driven publish flow is documented in RELEASING.md.

Requirements

  • Python >= 3.10
  • JAX >= 0.4

License

Apache-2.0


Download files


Source Distribution

jact-0.1.2.tar.gz (29.4 kB)

Uploaded Source

Built Distribution


jact-0.1.2-py3-none-any.whl (30.1 kB)

Uploaded Python 3

File details

Details for the file jact-0.1.2.tar.gz.

File metadata

  • Download URL: jact-0.1.2.tar.gz
  • Upload date:
  • Size: 29.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.2.tar.gz
Algorithm Hash digest
SHA256 54913682b92d980a7a024fddfe3fe86b355be3bc9860f2b143415a2827460a8e
MD5 cf7cb0c77d6c9b9e4be2b2a7d784e6c3
BLAKE2b-256 4db73d0712b71550534d018a13abf55163c5cebfa3375f271efbb4d9dc108d88


Provenance

The following attestation bundles were made for jact-0.1.2.tar.gz:

Publisher: publish.yml on stojl/jact

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jact-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: jact-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 30.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 bdc26b97554b3c000f4b89e62eb5dcc1bf7b682be3753dc15f2b6b906a283891
MD5 cb563061b8ea438cd45c852805db3569
BLAKE2b-256 8d299b7b1c0f1e820700c154aa89e0815b4831b3f3f43f7752c59642dd48c57f


Provenance

The following attestation bundles were made for jact-0.1.2-py3-none-any.whl:

Publisher: publish.yml on stojl/jact

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
