
JAX-based transition probability computation for multi-state models


jact

JAX-based transition probability computation for multi-state models with duration-dependent transition intensities.

What is jact?

jact computes transition probabilities in semi-Markov multi-state models. It takes fitted intensity models (parametric functions, GLMs, neural networks, or any other JIT-compatible callable) and produces transition probabilities for thousands of individuals in a single vectorized pass. Computations are optimized for JIT-compiled execution on GPU.
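To make "transition probabilities in a single vectorized pass" concrete, here is a minimal pure-JAX sketch that does not use jact at all. It is a Markov simplification (intensities depend on age only, not on duration), solves the Kolmogorov forward equation dp/dt = p Q(age) with explicit Euler steps, and batches over starting ages with jax.vmap. The intensity functions and their parameters are hypothetical, and this is not jact's solver.

```python
import jax
import jax.numpy as jnp


# Hypothetical Gompertz intensities (per unit time); parameters are illustrative only.
def mu_onset(age):
    return 0.0005 * jnp.exp(0.06 * age)


def mu_death_healthy(age):
    return 0.0001 * jnp.exp(0.09 * age)


def mu_death_disabled(age):
    return 0.0003 * jnp.exp(0.09 * age)


def generator(age):
    """Intensity matrix Q(age) for states (healthy, disabled, dead); rows sum to zero."""
    a, b, c = mu_onset(age), mu_death_healthy(age), mu_death_disabled(age)
    zero = jnp.zeros_like(a)
    return jnp.stack([
        jnp.stack([-(a + b), a, b]),    # healthy -> disabled, dead
        jnp.stack([zero, -c, c]),       # disabled -> dead
        jnp.stack([zero, zero, zero]),  # dead is absorbing
    ])


def transition_probs(age0, horizon=30, steps_per_unit=12):
    """Explicit Euler for dp/dt = p Q(age0 + t), starting in the healthy state."""
    h = 1.0 / steps_per_unit
    n_steps = horizon * steps_per_unit
    p0 = jnp.array([1.0, 0.0, 0.0])

    def step(p, k):
        age = age0 + k * h  # age at the start of step k
        p_next = p + h * (p @ generator(age))
        return p_next, p_next

    _, path = jax.lax.scan(step, p0, jnp.arange(n_steps))
    return path  # shape (n_steps, 3): state probabilities at t = h, 2h, ..., horizon


# vmap batches over starting ages; jit compiles the batched computation with XLA.
ages = jnp.linspace(30.0, 80.0, 1_000)
paths = jax.jit(jax.vmap(transition_probs))(ages)
print(paths.shape)  # (1000, 360, 3)
```

Because every row of Q sums to zero, each Euler step preserves total probability, so each row of `paths` sums to one up to floating-point error.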

Quick example

import jax.numpy as jnp
import jact

# Define the state space
state_space = jact.StateSpace(
    states=["healthy", "disabled", "dead"],
    transitions=[
        ("healthy", "disabled"),
        ("healthy", "dead"),
        ("disabled", "dead"),
    ],
)

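# Build a model by attaching an intensity function to each transition.
# onset_fn, mortality_fn and disabled_mort_fn are user-supplied JIT-compatible
# callables; see the API specification for the signature they must follow.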
model = state_space.build(
    transitions={
        ("healthy", "disabled"): onset_fn,
        ("healthy", "dead"): mortality_fn,
        ("disabled", "dead"): disabled_mort_fn,
    }
)

# Compute transition probabilities for 1,000 individuals aged 30 to 80, starting
# in "healthy", over a horizon of 30 time units with 12 steps per unit
ages = jnp.linspace(30, 80, 1_000)
result = model.solve(initial="healthy", horizon=30, steps_per_unit=12, age=ages)
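The Quick example passes onset_fn, mortality_fn and disabled_mort_fn without defining them. Below is a hypothetical sketch of what they might look like as plain JAX functions, covering three model types: a Gompertz-Makeham curve, a log-link GLM, and a duration-dependent intensity. The calling convention fn(age, duration) and every parameter value are assumptions for illustration; check the API specification for the signature jact actually expects.

```python
import jax
import jax.numpy as jnp

# ASSUMPTION: each intensity is called as fn(age, duration) and returns an array of
# non-negative intensities. The real jact signature may differ.


def onset_fn(age, duration):
    # Gompertz-Makeham: constant term plus exponential growth in age; duration unused.
    del duration
    return 0.0005 + 0.00003 * jnp.exp(0.08 * age)


def mortality_fn(age, duration):
    # Log-link GLM with a linear predictor in age (illustrative coefficients).
    del duration
    return jnp.exp(-10.0 + 0.09 * age)


def disabled_mort_fn(age, duration):
    # Duration-dependent (semi-Markov): excess mortality right after disability
    # onset that decays with time spent in the disabled state.
    return jnp.exp(-9.0 + 0.09 * age) * (1.0 + 2.0 * jnp.exp(-duration))


# All three are pure jnp functions, so they can be JIT-compiled and evaluated in batch.
ages = jnp.linspace(30.0, 80.0, 1_000)
durations = jnp.zeros_like(ages)
for fn in (onset_fn, mortality_fn, disabled_mort_fn):
    values = jax.jit(fn)(ages, durations)
    print(fn.__name__, values.shape)
```

Swapping one of these for a neural network only requires that the replacement is also a JIT-compatible function of the same inputs.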

Key features

  • Plug in any model: Gompertz, GLM, neural network, or any other JIT-compatible callable.
  • Swap and compare: Reuse the same StateSpace with different intensity models to compare them.
  • Compute only what's needed: The solver restricts the computation to states reachable from the initial state.
  • Exact seeded starts: Initial point masses preserve each individual's starting duration d_0 exactly.
  • Batch-first: Designed for 100K+ individuals in a single pass.
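The reachability reduction can be illustrated with a plain breadth-first search over the transition list. This is a sketch of the idea only, not jact's implementation; reachable_states is a hypothetical helper. Starting in "disabled", the healthy state cannot be reached, so a solver only needs to track two of the three states.

```python
from collections import deque


def reachable_states(transitions, initial):
    """Return the set of states reachable from `initial` along directed transitions (BFS)."""
    successors = {}
    for src, dst in transitions:
        successors.setdefault(src, []).append(dst)

    seen = {initial}
    queue = deque([initial])
    while queue:
        state = queue.popleft()
        for nxt in successors.get(state, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


transitions = [
    ("healthy", "disabled"),
    ("healthy", "dead"),
    ("disabled", "dead"),
]
print(sorted(reachable_states(transitions, "disabled")))  # ['dead', 'disabled']
```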

Documentation

See the documentation index for the public documentation set. The API specification defines the full API contract, and the example notebook gives a runnable walkthrough of the main workflow.

Namespace

The recommended user API is the top-level jact namespace: jact.StateSpace, jact.InitialDistribution, jact.solve, and the cashflow declarations such as jact.StateRate and jact.Total. Advanced callback state objects remain available from submodules, for example jact.callbacks.PointMass and jact.model.ReducedModel.

Installation

pip install jax jaxlib
pip install jact

For local development from this repository:

pip install -e '.[dev]'
pytest

The package uses a src/ layout, so an editable install is the intended local workflow.

To run the example notebook with plotting support from a local checkout:

pip install -e '.[dev,notebook]'

Release checks

Before cutting a PyPI release:

rm -rf build dist src/*.egg-info
python -m build --no-isolation
python -m twine check dist/*
pytest tests/test_solver.py tests/test_initial_distribution_integration.py -q

The tag-driven publish flow is documented in RELEASING.md.

Requirements

  • Python >= 3.10
  • JAX >= 0.4

License

Apache-2.0



Download files


Source Distribution

jact-0.1.1.tar.gz (29.0 kB)

Built Distribution


jact-0.1.1-py3-none-any.whl (29.7 kB)

File details

Details for the file jact-0.1.1.tar.gz.

File metadata

  • Download URL: jact-0.1.1.tar.gz
  • Upload date:
  • Size: 29.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.1.tar.gz
  • SHA256: 76a0eae3a4ea7c9f95bbd98ccb825ecbb79b0fc9a1d3f42a999f850f85dc327b
  • MD5: 17a11151bde3e89803d24d8275b04f9a
  • BLAKE2b-256: 435359a19f7035cc70cdb2e76364ecf285c6298f94095900986d77f0568fd37a
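To check a downloaded archive against the published SHA256 digest for jact-0.1.1.tar.gz, the digest can be recomputed with Python's standard hashlib module. A minimal sketch, assuming the sdist was saved as jact-0.1.1.tar.gz in the current directory; sha256_of and verify are hypothetical helper names.

```python
import hashlib
from pathlib import Path

# Published SHA256 digest for jact-0.1.1.tar.gz.
EXPECTED_SHA256 = "76a0eae3a4ea7c9f95bbd98ccb825ecbb79b0fc9a1d3f42a999f850f85dc327b"


def sha256_of(path, chunk_size=1 << 20):
    """Hash the file in chunks so large archives are never fully loaded into memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()


archive = Path("jact-0.1.1.tar.gz")  # assumed download location
if archive.exists():
    print("hash OK" if verify(archive) else "hash MISMATCH")
```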


Provenance

The following attestation bundles were made for jact-0.1.1.tar.gz:

Publisher: publish.yml on stojl/jact

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jact-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: jact-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 29.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for jact-0.1.1-py3-none-any.whl
  • SHA256: 5240c20580e0ab6239b8465986cdba5dd166f42a089912475f108e69848f4118
  • MD5: 864dba1e13155b2ab61a80b8786d44a9
  • BLAKE2b-256: b34476c8b1c79a77c2d2e25e6a14ab4a6958c74c0f7d254173ff6e4404772293


Provenance

The following attestation bundles were made for jact-0.1.1-py3-none-any.whl:

Publisher: publish.yml on stojl/jact

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
