
Joint modeling with automatic differentiation


jmstate

jmstate is a Python package for nonlinear multi-state joint modeling of longitudinal and time-to-event data. Built on PyTorch, it enables flexible specification of regression and link functions — including neural networks — while still offering built-in parametric baseline hazards and utilities for inference and prediction.

The package implements the methodology from:

Félix Laplante and Christophe Ambroise (2025). A General Framework for Joint Multi-State Models. arXiv:2510.07128.


Installation

pip install jmstate

Requirements: Python ≥ 3.10, PyTorch, scikit-learn, NumPy, Matplotlib, rich, tqdm.


Documentation

Full API reference and tutorials: jmstate documentation


The Model

jmstate fits a joint model that links a longitudinal biomarker process to a multi-state event history through shared individual random effects.

Longitudinal sub-model

Individual observations follow

$$y_{ij} = h(t_{ij}, \psi_i) + \epsilon_{ij}, \qquad \epsilon_{ij} \sim \mathcal{N}(0, R)$$

where $h$ is a user-defined regression function (e.g. bi-exponential, logistic) and individual parameters are defined via

$$\psi_i = f(\gamma, X_i, b_i), \qquad b_i \sim \mathcal{N}(0, Q)$$

where $\gamma$ are the fixed population-level effects, $X_i$ the individual covariates, and $b_i$ the individual random effects.
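As an illustration of the two equations above, the following pure-Python sketch simulates one individual under assumptions chosen to mirror the Quick Start below: $h$ is the bi-exponential curve with $\psi_i = (A, k, k_a)$, $f$ is $\gamma \odot \exp(b_i)$ (covariates ignored), $Q$ is diagonal and $R$ is a scalar variance. The function names `f`, `h` and `simulate_individual` are hypothetical; this is not jmstate code.

```python
import math
import random


def f(gamma: list[float], x: list[float], b: list[float]) -> list[float]:
    """Individual parameters psi_i = gamma * exp(b_i); the covariates x are unused here."""
    return [g * math.exp(bk) for g, bk in zip(gamma, b)]


def h(t: float, psi: list[float]) -> float:
    """Bi-exponential regression function with psi = (A, k, ka)."""
    A, k, ka = psi
    return A * (math.exp(-k * t) - math.exp(-ka * t))


def simulate_individual(
    gamma: list[float],
    q_sd: list[float],
    r_sd: float,
    x: list[float],
    times: list[float],
    rng: random.Random,
) -> tuple[list[float], list[float]]:
    """Draw b_i ~ N(0, diag(q_sd**2)), then y_ij = h(t_ij, psi_i) + eps_ij, eps_ij ~ N(0, r_sd**2)."""
    b = [rng.gauss(0.0, sd) for sd in q_sd]  # random effects
    psi = f(gamma, x, b)  # individual parameters
    y = [h(t, psi) + rng.gauss(0.0, r_sd) for t in times]  # noisy observations
    return psi, y


rng = random.Random(0)
psi, y = simulate_individual([2.0, 0.5, 1.5], [0.2, 0.2, 0.2], 0.05, [], [0.5, 1.0, 2.0, 4.0], rng)
```

Setting every standard deviation to zero recovers the population curve $h(t, \gamma)$, which is a convenient sanity check for a custom $h$ and $f$.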

Multi-state sub-model

Let $G = (V, E)$ be a directed graph, where $V$ denotes the set of states and $E \subseteq V \times V$ the set of admissible transitions. The graph encodes all possible paths of the multi-state process, allowing for competing, recurrent, or absorbing transitions. The hazard for a transition $k \to k'$ at time $t$ given entry time $t_0$ satisfies

$$\lambda^{k \to k'}(t_0, t) = \lambda_0^{k \to k'}(t_0, t) \exp\left( \alpha^{k \to k'} g^{k \to k'}(t, \psi_i) + \beta^{k \to k'} X_i \right),$$

where $\lambda_0^{k \to k'}$ are parametric baseline hazards, $g^{k \to k'}$ are link functions acting as a bridge between the longitudinal and the semi-Markov multi-state processes, and $\alpha^{k \to k'}$, $\beta^{k \to k'}$ are transition-specific coefficients.

The model supports arbitrary state graphs (recurrent, absorbing, monotone, etc.) under a semi-Markov assumption.
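The hazard formula can be made concrete with a small sketch. Assuming a scalar covariate, per-edge link functions already evaluated for a fixed individual, and constant (exponential) baselines, the probability of remaining in a state $k$ from entry time $t_0$ to $t$ is $\exp\left(-\sum_{k'} \int_{t_0}^{t} \lambda^{k \to k'}(t_0, u)\, du\right)$, summing over the edges leaving $k$. The baseline takes $(t_0, t)$ so that a semi-Markov baseline could depend on the time since entry $t - t_0$. All names here are hypothetical and the integral uses the trapezoid rule; jmstate's own implementation may differ.

```python
import math
from typing import Callable

Edge = tuple[int, int]
BaseHazard = Callable[[float, float], float]  # (t0, t) -> lambda_0(t0, t)
Link = Callable[[float], float]  # t -> g(t, psi_i) for one fixed individual


def exponential_base(rate: float) -> BaseHazard:
    """Constant baseline hazard: ignores both the entry time t0 and the current time t."""
    return lambda t0, t: rate


def transition_hazard(
    t0: float, t: float, base: BaseHazard, alpha: float, beta: float, link: Link, x: float
) -> float:
    """lambda(t0, t) = lambda_0(t0, t) * exp(alpha * g(t) + beta * x)."""
    return base(t0, t) * math.exp(alpha * link(t) + beta * x)


def stay_probability(
    state: int,
    t0: float,
    t: float,
    transitions: dict[Edge, tuple[BaseHazard, float, float, Link]],
    x: float,
    n_steps: int = 1000,
) -> float:
    """P(no transition out of `state` during (t0, t]) via trapezoid-rule cumulative hazards."""
    step = (t - t0) / n_steps
    total = 0.0
    for (src, _dst), (base, alpha, beta, link) in transitions.items():
        if src != state:
            continue  # only edges leaving `state` compete
        values = [
            transition_hazard(t0, t0 + j * step, base, alpha, beta, link, x)
            for j in range(n_steps + 1)
        ]
        total += step * (sum(values) - 0.5 * (values[0] + values[-1]))
    return math.exp(-total)


# Same graph as in the Quick Start: 1 -> 1 (recurrent) and 1 -> 2
transitions = {
    (1, 1): (exponential_base(0.2), 0.0, 0.0, lambda t: 0.0),
    (1, 2): (exponential_base(0.1), 0.0, 0.0, lambda t: 0.0),
}
p = stay_probability(1, 0.5, 2.5, transitions, x=0.0)
```

With $\alpha = \beta = 0$ the hazards are constant, so `p` reduces to $\exp(-(0.2 + 0.1)(2.5 - 0.5)) = e^{-0.6}$; a state with no outgoing edge (here state 2) is absorbing and always gives 1.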

Estimation

Parameters are estimated by maximising the observed-data log-likelihood, whose gradient is given by the Fisher identity

$$\nabla_\theta \log \mathcal{L}(\theta; x) = \mathbb{E}_{b \sim p(\cdot \mid x, \theta)}\left[ \nabla_\theta \log \mathcal{L}(\theta; x, b) \right],$$

where $\mathcal{L}(\theta; x, b)$ is the complete likelihood of the data given the parameters and random effects.
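The identity follows from writing the observed-data likelihood as the marginal $\mathcal{L}(\theta; x) = \int \mathcal{L}(\theta; x, b)\, db$ and differentiating under the integral sign (assuming this interchange is valid):

$$\nabla_\theta \log \mathcal{L}(\theta; x) = \frac{\int \nabla_\theta \mathcal{L}(\theta; x, b)\, db}{\mathcal{L}(\theta; x)} = \int \nabla_\theta \log \mathcal{L}(\theta; x, b)\, \frac{\mathcal{L}(\theta; x, b)}{\mathcal{L}(\theta; x)}\, db,$$

and the ratio $\mathcal{L}(\theta; x, b) / \mathcal{L}(\theta; x)$ is exactly the posterior density $p(b \mid x, \theta)$ of the random effects.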

This gradient is approximated by averaging the complete-data gradient over draws from a Metropolis-within-Gibbs MCMC sampler on the random effects, and each estimate drives a stochastic gradient ascent step. Convergence is monitored with an $R^2$-based stationarity test.
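To make the estimation loop concrete, here is a toy sketch on a model where everything is known in closed form: $b_i \sim \mathcal{N}(0, 1)$ and $x_i \mid b_i \sim \mathcal{N}(\theta + b_i, 1)$, so marginally $x_i \sim \mathcal{N}(\theta, 2)$ and the maximum-likelihood estimate of $\theta$ is the sample mean. It is not jmstate's implementation: a random-walk Metropolis update of each individual's random effect stands in for the Metropolis-within-Gibbs sampler, a fixed step size with averaging of the final iterates stands in for the optimizer, and the $R^2$ stationarity test is not reproduced. All names are hypothetical.

```python
import math
import random


def log_complete(x: float, b: float, theta: float) -> float:
    """Complete log-likelihood up to a constant: x | b ~ N(theta + b, 1) and b ~ N(0, 1)."""
    return -0.5 * (x - theta - b) ** 2 - 0.5 * b**2


def grad_complete(x: float, b: float, theta: float) -> float:
    """Gradient of log_complete with respect to theta."""
    return x - theta - b


def metropolis_step(x: float, b: float, theta: float, scale: float, rng: random.Random) -> float:
    """One random-walk Metropolis update of b, targeting p(b | x; theta)."""
    proposal = b + scale * rng.gauss(0.0, 1.0)
    log_ratio = log_complete(x, proposal, theta) - log_complete(x, b, theta)
    return proposal if rng.random() < math.exp(min(0.0, log_ratio)) else b


def fit(xs: list[float], n_iter: int = 300, n_mcmc: int = 5, lr: float = 0.5, seed: int = 0) -> float:
    rng = random.Random(seed)
    theta = 0.0
    bs = [0.0] * len(xs)  # one persistent Markov chain per individual
    history: list[float] = []
    for _iteration in range(n_iter):
        # MCMC: refresh each individual's random effect given the current theta
        for i, x in enumerate(xs):
            for _step in range(n_mcmc):
                bs[i] = metropolis_step(x, bs[i], theta, 1.0, rng)
        # Fisher identity: average complete-data gradients over the current posterior draws
        grad = sum(grad_complete(x, b, theta) for x, b in zip(xs, bs)) / len(xs)
        theta += lr * grad  # stochastic gradient ascent step
        history.append(theta)
    tail = history[-100:]
    return sum(tail) / len(tail)  # average late iterates to damp Monte Carlo noise


data_rng = random.Random(42)
xs = [2.0 + data_rng.gauss(0.0, 1.0) + data_rng.gauss(0.0, 1.0) for _ in range(200)]
theta_hat = fit(xs)
x_bar = sum(xs) / len(xs)  # exact MLE, since x ~ N(theta, 2)
```

Because the posterior mean of $b_i$ is $(x_i - \theta)/2$, the averaged gradient estimates $(\bar{x} - \theta)/2$, so `theta_hat` should settle near `x_bar` up to Monte Carlo error.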


Quick Start

Step 1 — Define the model design

import torch
from jmstate.types import ModelDesign


# Individual parameters
def indiv_effects_fn(
    fixed: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    return fixed * torch.exp(b)  # (..., n, q)


# PK function: bi-exponential biomarker h(t, psi) with psi = (A, k, ka)
def pk_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)  # each of shape (..., n, 1)
    conc = A * (torch.exp(-k * t) - torch.exp(-ka * t))
    return conc.unsqueeze(-1)  # add the output dimension d = 1


# PK integral function: integral of pk_fn over [0, t], used as a cumulative link
def pk_integral_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    # A multiplies both terms: A * ((1 - e^{-kt}) / k - (1 - e^{-ka t}) / ka)
    integral = A * ((1 - torch.exp(-k * t)) / k - (1 - torch.exp(-ka * t)) / ka)
    return integral.unsqueeze(-1)


# Define the model design
design = ModelDesign(
    indiv_effects_fn,
    regression_fn=pk_fn,
    link_fns={
        (1, 1): pk_integral_fn,
        (1, 2): pk_integral_fn,
    },
)
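A quick way to check a hand-derived cumulative link such as `pk_integral_fn` is to compare it with a numerical integral of the regression function. The sketch below does this for scalar inputs in pure Python, writing the integral of $A(e^{-kt} - e^{-k_a t})$ over $[0, t]$ as $A\left(\frac{1 - e^{-kt}}{k} - \frac{1 - e^{-k_a t}}{k_a}\right)$; the helper names and parameter values are illustrative only.

```python
import math
from typing import Callable


def pk(t: float, A: float, k: float, ka: float) -> float:
    """Scalar bi-exponential curve A * (exp(-k t) - exp(-ka t))."""
    return A * (math.exp(-k * t) - math.exp(-ka * t))


def pk_integral(t: float, A: float, k: float, ka: float) -> float:
    """Closed-form integral of pk over [0, t]."""
    return A * ((1 - math.exp(-k * t)) / k - (1 - math.exp(-ka * t)) / ka)


def trapezoid(fn: Callable[[float], float], a: float, b: float, n: int = 10_000) -> float:
    """Composite trapezoid rule on n equal sub-intervals of [a, b]."""
    h = (b - a) / n
    return h * (0.5 * fn(a) + sum(fn(a + j * h) for j in range(1, n)) + 0.5 * fn(b))


A, k, ka = 2.0, 0.5, 1.5
closed = pk_integral(3.0, A, k, ka)
numeric = trapezoid(lambda s: pk(s, A, k, ka), 0.0, 3.0)
```

The two values should agree to several decimal places; a large gap usually points to a misplaced parenthesis or a sign error in the closed form.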

Step 2 — Set initial parameters

from jmstate.functions.base_hazards import Exponential
from jmstate.types import ModelParameters, PrecisionParameters

# Define simple initial parameters
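params = ModelParameters(
    torch.ones(3),  # fixed effects gamma, one per individual parameter (A, k, ka)
    PrecisionParameters.from_covariance(torch.eye(3), "diag"),  # random-effects covariance Q
    PrecisionParameters.from_covariance(torch.eye(1), "spherical"),  # residual covariance R
    {(1, 1): Exponential(1.0), (1, 2): Exponential(1.0)},  # baseline hazards
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},  # transition-specific coefficients
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},  # transition-specific coefficients
)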

Step 3 — Prepare data

from jmstate.types import ModelData

data = ModelData(
    x,  # (n, p) covariate matrix
    t,  # (m,) or (n, m) measurement times; NaN-pad if variable
    y,  # (n, m, d) longitudinal observations; NaN-pad if variable
    trajectories,  # list[list[tuple[float, Any]]]
    c,  # (n, 1) right-censoring times
)

Each trajectory is a chronologically ordered list of (time, state) tuples representing the individual's event history.
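Variable-length inputs need some preparation before they fit the shapes above. The sketch below, with hypothetical helpers `nan_pad` and `check_trajectory`, right-pads ragged measurement times with NaN and checks that each trajectory is chronologically ordered and only uses transitions of the state graph (here the Quick Start edges `(1, 1)` and `(1, 2)`). In this toy example each trajectory starts with the individual's initial state; jmstate may enforce further conventions that are not reproduced here.

```python
import math
from typing import Any

Trajectory = list[tuple[float, Any]]


def nan_pad(rows: list[list[float]]) -> list[list[float]]:
    """Right-pad ragged per-individual rows with NaN to a common length m."""
    m = max(len(row) for row in rows)
    return [list(row) + [math.nan] * (m - len(row)) for row in rows]


def check_trajectory(traj: Trajectory, edges: set[tuple[Any, Any]]) -> None:
    """Raise ValueError if times decrease or a jump is not an edge of the state graph."""
    for (t1, s1), (t2, s2) in zip(traj, traj[1:]):
        if t2 < t1:
            raise ValueError(f"time {t2} precedes {t1}; trajectory is not chronological")
        if (s1, s2) not in edges:
            raise ValueError(f"transition {s1} -> {s2} is not in the state graph")


times = nan_pad([[0.5, 1.0, 2.0], [0.5, 1.5]])  # shape (n, m) = (2, 3) after padding
trajectories = [
    [(0.0, 1), (1.2, 1), (3.0, 2)],  # a recurrent 1 -> 1 event, then 1 -> 2
    [(0.0, 1)],  # still in state 1 when censored
]
for traj in trajectories:
    check_trajectory(traj, {(1, 1), (1, 2)})
```

Once padded, the nested lists can be turned into a tensor with `torch.tensor(times)`, which keeps the NaN entries as NaN.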

Step 4 — Fit the model

import matplotlib.pyplot as plt
from jmstate import MultiStateJointModel

optimizer = torch.optim.Adam(params.parameters(), lr=0.1)
model = MultiStateJointModel(design, params, optimizer)

metrics = model.fit(data)

Step 5 — Print and plot the results

from jmstate.utils import plot_mcmc_diagnostics, plot_params_history

# Compute and print summary statistics (Wald tests of nullity, p-values, AIC, BIC, etc.)
model.compute_summary().summary()

# Plot parameter history (stochastic optimization)...
plot_params_history(model)
plt.show()

# ...and MCMC sampler diagnostics
plot_mcmc_diagnostics(model)
plt.show()

License

See LICENSE.

