
Joint modeling with automatic differentiation


jmstate

jmstate is a Python package for nonlinear multi-state joint modeling of longitudinal and time-to-event data. Built on PyTorch, it enables flexible specification of regression and link functions — including neural networks — while still offering built-in parametric baseline hazards and utilities for inference and prediction.

The package implements the methodology from:

A General Framework for Joint Multi-State Models. Félix Laplante & Christophe Ambroise (2025), arXiv:2510.07128.


Installation

pip install jmstate

Requirements: Python ≥ 3.10, PyTorch, scikit-learn, NumPy, Matplotlib, rich, tqdm.


Documentation

Full API reference and tutorials: jmstate documentation


The Model

jmstate fits a joint model that links a longitudinal biomarker process to a multi-state event history through shared individual random effects.

Longitudinal sub-model

Individual observations follow

$$y_{ij} = h(t_{ij}, \psi_i) + \epsilon_{ij}, \qquad \epsilon_{ij} \sim \mathcal{N}(0, R)$$

where $h$ is a user-defined regression function (e.g. bi-exponential, logistic) and individual parameters are defined via

$$\psi_i = f(\gamma, X_i, b_i), \qquad b_i \sim \mathcal{N}(0, Q)$$

where $\gamma$ denotes the fixed population-level effects, $X_i$ the covariates, $b_i$ the individual random effects, and $Q$ and $R$ the random-effect and residual covariance matrices.
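To make the longitudinal sub-model concrete, here is a minimal NumPy simulation sketch (not jmstate code). It assumes a univariate biomarker with scalar residual variance $R$, uses the same multiplicative effect $\psi_i = \gamma \odot \exp(b_i)$ as the Quick Start (covariates ignored), and takes the bi-exponential $h$ from there too; the helper names are hypothetical.

```python
import numpy as np


def indiv_params(gamma, b):
    """psi_i = f(gamma, X_i, b_i) = gamma * exp(b_i); covariates are ignored here."""
    return gamma * np.exp(b)  # (n, q)


def h(t, psi):
    """Bi-exponential regression function with psi = (A, k, ka)."""
    A, k, ka = psi[:, 0:1], psi[:, 1:2], psi[:, 2:3]  # each (n, 1)
    return A * (np.exp(-k * t) - np.exp(-ka * t))  # broadcasts to (n, m)


def simulate_longitudinal(gamma, Q, R, t, n, rng):
    """Draw b_i ~ N(0, Q), then y_ij = h(t_j, psi_i) + eps_ij with eps_ij ~ N(0, R)."""
    b = rng.multivariate_normal(np.zeros(gamma.shape[0]), Q, size=n)  # (n, q)
    psi = indiv_params(gamma, b)
    eps = rng.normal(0.0, np.sqrt(R), size=(n, t.shape[0]))  # (n, m)
    return h(t, psi) + eps, b


rng = np.random.default_rng(0)
gamma = np.array([2.0, 0.3, 1.5])  # illustrative population values of (A, k, ka)
t = np.linspace(0.5, 10.0, 8)      # common measurement grid, shape (m,)
y, b = simulate_longitudinal(gamma, 0.05 * np.eye(3), 0.01, t, n=5, rng=rng)
print(y.shape)  # (5, 8)
```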

Multi-state sub-model

Let $G = (V, E)$ be a directed graph, where $V$ is the set of states and $E \subseteq V \times V$ the set of admissible transitions. The graph encodes all possible paths of the multi-state process, allowing for competing, recurrent, or absorbing transitions. For individual $i$, the hazard of a transition $k \to k'$ at time $t$, given entry into state $k$ at time $t_0$, satisfies

$$\lambda^{k \to k'}(t_0, t) = \lambda_0^{k \to k'}(t_0, t) \exp\left( \alpha^{k \to k'} g^{k \to k'}(t, \psi_i) + \beta^{k \to k'} X_i \right),$$

where $\lambda_0^{k \to k'}$ are parametric baseline hazards, $g^{k \to k'}$ are link functions acting as a bridge between the longitudinal and the semi-Markov multi-state processes, and $\alpha^{k \to k'}$, $\beta^{k \to k'}$ are transition-specific coefficients.
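Under a constant (exponential) baseline, the hazard and the probability of not yet having made a given transition can be computed numerically. This is a NumPy sketch, not jmstate's implementation: the function names are hypothetical, the link $g$ is passed as a function of $t$ only (individual parameters held fixed), and the cumulative hazard $\Lambda(t_0, t) = \int_{t_0}^{t} \lambda(t_0, s)\,ds$ is approximated with the trapezoidal rule.

```python
import numpy as np


def transition_hazard(t, base_rate, alpha, beta, x, link):
    """lambda(t0, t) = lambda0 * exp(alpha * g(t) + beta . x) for one individual.

    The exponential baseline is constant in the sojourn time t - t0, so t0 only
    enters through the integration bounds in cumulative_hazard.
    """
    t = np.asarray(t, dtype=float)
    return base_rate * np.exp(alpha * link(t) + np.dot(beta, x))


def cumulative_hazard(t0, t, n_grid=2001, **hazard_kw):
    """Lambda(t0, t): trapezoidal integral of the hazard over [t0, t]."""
    grid = np.linspace(t0, t, n_grid)
    lam = transition_hazard(grid, **hazard_kw)
    return float(np.sum(0.5 * (lam[1:] + lam[:-1]) * np.diff(grid)))


def transition_survival(t0, t, **hazard_kw):
    """exp(-Lambda(t0, t)) for this single k -> k' transition."""
    return float(np.exp(-cumulative_hazard(t0, t, **hazard_kw)))


kw = dict(base_rate=0.2, alpha=0.5, beta=np.array([0.1]), x=np.array([1.0]),
          link=lambda s: 1.0 - np.exp(-s))  # hypothetical bounded link g(t)
print(transition_survival(1.0, 4.0, **kw))
```

With several transitions out of $k$, the probability of still being in $k$ at time $t$ is the product of these terms over all $k'$ with $(k, k') \in E$, because the cumulative hazards add.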

The model supports arbitrary state graphs (recurrent, absorbing, monotone, etc.) under a semi-Markov assumption.
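The state graph itself is just a set of ordered pairs. As an illustration (plain Python, hypothetical helper names), the sketch below derives each state's outgoing transitions and flags absorbing states (no outgoing edge) and recurrent ones (a self-loop). It uses the two transitions keyed as `(1, 1)` and `(1, 2)` in the Quick Start, on the assumption that those keys are (from, to) state pairs.

```python
from collections import defaultdict


def out_transitions(edges):
    """Map each state k to {k' : (k, k') in E}."""
    out = defaultdict(set)
    for k, k_next in edges:
        out[k].add(k_next)
    return dict(out)


def classify_states(states, edges):
    """Split V into absorbing states (no exit) and recurrent states (self-loop k -> k)."""
    out = out_transitions(edges)
    absorbing = {k for k in states if not out.get(k)}
    recurrent = {k for k in states if k in out.get(k, set())}
    return absorbing, recurrent


V = {1, 2}
E = {(1, 1), (1, 2)}  # repeated events in state 1, exit to state 2
print(classify_states(V, E))  # ({2}, {1})
```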

Estimation

Parameters are estimated by maximising the observed-data log-likelihood, whose gradient is given by the Fisher identity

$$\nabla_\theta \log \mathcal{L}(\theta; x) = \mathbb{E}_{b \sim p(\cdot \mid x, \theta)}\left[ \nabla_\theta \log \mathcal{L}(\theta; x, b) \right],$$

where $\mathcal{L}(\theta; x, b)$ is the complete-data likelihood, i.e. the joint likelihood of the observed data $x$ and the random effects $b$ given the parameters $\theta$.
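The identity can be checked numerically on a toy model where everything is tractable: $x_i = \theta + b_i + \epsilon_i$ with $b_i, \epsilon_i \sim \mathcal{N}(0, 1)$ (a hypothetical example, not one of jmstate's models). Marginally $x_i \sim \mathcal{N}(\theta, 2)$, so the exact score is $\sum_i (x_i - \theta)/2$; the complete-data score is $\sum_i (x_i - \theta - b_i)$, and the posterior $b_i \mid x_i \sim \mathcal{N}((x_i - \theta)/2, 1/2)$ can be sampled exactly. In jmstate that posterior has no closed form, which is why MCMC is needed.

```python
import numpy as np


def marginal_score(theta, x):
    """Exact gradient of log L(theta; x), using x_i ~ N(theta, 2)."""
    return float(np.sum(x - theta) / 2.0)


def complete_score(theta, x, b):
    """Gradient of log L(theta; x, b); in this toy model the prior on b does not involve theta."""
    return np.sum(x - theta - b, axis=-1)


def fisher_identity_estimate(theta, x, n_samples, rng):
    """Monte Carlo average of the complete-data score over b ~ p(b | x, theta)."""
    post_mean = (x - theta) / 2.0  # posterior mean of each b_i
    post_sd = np.sqrt(0.5)         # posterior standard deviation
    b = post_mean + post_sd * rng.standard_normal((n_samples, x.shape[0]))
    return float(np.mean(complete_score(theta, x, b)))


rng = np.random.default_rng(0)
x = rng.normal(1.0, np.sqrt(2.0), size=50)
print(marginal_score(0.3, x), fisher_identity_estimate(0.3, x, 20_000, rng))
```

The two printed values should agree up to Monte Carlo error.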

The expectation is approximated by Monte Carlo, with the random effects drawn by a Metropolis-within-Gibbs MCMC sampler, and the resulting gradient estimate drives a stochastic gradient ascent step. Convergence is monitored via an $R^2$-based stationarity test.
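On a toy model $x_i = \theta + b_i + \epsilon_i$ with $b_i, \epsilon_i \sim \mathcal{N}(0, 1)$ (hypothetical; $\theta$ is the only unknown), the loop can be sketched as: a few Metropolis-within-Gibbs sweeps over the $b_i$, then a gradient step using the complete-data score at the current draws. The decreasing step size and the stopping rule are assumptions made for this sketch; in particular, `r2_stationary` is one plausible reading of an $R^2$-based test (regress recent iterates on the iteration index and stop when the trend explains little variance), not jmstate's exact criterion. The estimate should settle near the maximum-likelihood value $\bar{x}$.

```python
import numpy as np


def log_target(b, theta, x):
    """log p(x_i, b_i | theta) up to a constant, elementwise over individuals."""
    return -0.5 * (x - theta - b) ** 2 - 0.5 * b ** 2


def mwg_sweep(b, theta, x, step, rng):
    """One Metropolis-within-Gibbs sweep: a random-walk update of each b_i.

    The posterior factorises over individuals, so proposing all b_i at once and
    accepting or rejecting each separately equals updating them one by one.
    """
    proposal = b + step * rng.standard_normal(b.shape)
    log_ratio = log_target(proposal, theta, x) - log_target(b, theta, x)
    accept = np.log(rng.uniform(size=b.shape)) < log_ratio
    return np.where(accept, proposal, b)


def r2_stationary(history, window=200, threshold=0.1):
    """Hypothetical R^2 check: no linear trend left in the last `window` iterates."""
    if len(history) < window:
        return False
    h = np.asarray(history[-window:])
    if np.std(h) == 0.0:
        return True
    r = np.corrcoef(np.arange(window), h)[0, 1]
    return r ** 2 < threshold


def fit(x, theta=0.0, lr=0.5, n_iter=5000, n_sweeps=5, step=1.0, seed=0):
    """Stochastic gradient ascent on theta driven by MCMC draws of b."""
    rng = np.random.default_rng(seed)
    b = np.zeros_like(x)
    history = []
    for it in range(n_iter):
        for _ in range(n_sweeps):
            b = mwg_sweep(b, theta, x, step, rng)
        grad = np.mean(x - theta - b)         # complete-data score / n
        theta += lr / (1 + it) ** 0.6 * grad  # decreasing (Robbins-Monro) step
        history.append(float(theta))
        if r2_stationary(history):
            break
    return float(theta), history


rng = np.random.default_rng(42)
x = rng.normal(2.0, np.sqrt(2.0), size=200)
theta_hat, history = fit(x)
print(theta_hat, x.mean(), len(history))
```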


Quick Start

Step 1 — Define the model design

import torch
from jmstate.types import ModelDesign


# Individual parameters
def indiv_effects_fn(
    fixed: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    return fixed * torch.exp(b)  # (..., n, q)


# PK function: bi-exponential biomarker
def pk_fn(t: torch.Tensor, indiv_params: torch.Tensor, D: float = 1.0):
    A, k, ka = indiv_params.chunk(3, dim=-1)
    conc = A * (torch.exp(-k * t) - torch.exp(-ka * t))
    return conc.unsqueeze(-1)


# PK integral function: bi-exponential cumulative link
def pk_integral_fn(t: torch.Tensor, indiv_params: torch.Tensor):
    A, k, ka = indiv_params.chunk(3, dim=-1)
    integral = A * ((1 - torch.exp(-k * t)) / k - (1 - torch.exp(-ka * t)) / ka)
    return integral.unsqueeze(-1)


# Define the model design
design = ModelDesign(
    indiv_effects_fn,
    regression_fn=pk_fn,
    link_fns={
        (1, 1): pk_integral_fn,
        (1, 2): pk_integral_fn,
    },
)
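Because `pk_integral_fn` is meant to be the cumulative exposure $\int_0^t h(s, \psi)\,ds$ of `pk_fn`, it can be checked against a numerical integral. A NumPy sketch with scalar parameters (function names hypothetical, parameter values illustrative):

```python
import numpy as np


def pk(t, A, k, ka):
    """Bi-exponential concentration, as in pk_fn."""
    return A * (np.exp(-k * t) - np.exp(-ka * t))


def pk_integral(t, A, k, ka):
    """Closed-form integral of pk over [0, t]; A multiplies both terms."""
    return A * ((1 - np.exp(-k * t)) / k - (1 - np.exp(-ka * t)) / ka)


def trapezoid(f, a, b, n=10_001):
    """Trapezoidal approximation of the integral of f over [a, b]."""
    s = np.linspace(a, b, n)
    fs = f(s)
    return float(np.sum(0.5 * (fs[1:] + fs[:-1]) * np.diff(s)))


A, k, ka = 2.0, 0.3, 1.5
for t in (1.0, 5.0, 20.0):
    numeric = trapezoid(lambda s: pk(s, A, k, ka), 0.0, t)
    print(t, pk_integral(t, A, k, ka), numeric)
```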

Step 2 — Set initial parameters

from jmstate.functions.base_hazards import Exponential
from jmstate.types import ModelParameters, PrecisionParameters

# Define simple initial parameters
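params = ModelParameters(
    torch.ones(3),  # fixed population effects gamma, one per individual parameter
    PrecisionParameters.from_covariance(torch.eye(3), "diag"),  # random-effects covariance Q
    PrecisionParameters.from_covariance(torch.eye(1), "spherical"),  # residual covariance R
    {(1, 1): Exponential(1.0), (1, 2): Exponential(1.0)},  # baseline hazard per transition
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},  # transition-specific coefficients
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},  # transition-specific coefficients
)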
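The `"diag"` and `"spherical"` arguments restrict the covariance structure. Judging from the shapes, the 3 × 3 block plays the role of $Q$ (three random effects) and the 1 × 1 block that of $R$. The sketch below shows one common way to parametrise such structures with unconstrained log standard deviations, plus a `"full"` Cholesky variant and a free-parameter count (the quantity that enters AIC/BIC penalties). It is a hypothetical NumPy illustration, not a description of how `PrecisionParameters` stores its values.

```python
import numpy as np


def covariance_from_unconstrained(theta, dim, kind):
    """Build a covariance matrix from unconstrained parameters.

    spherical: sigma^2 * I, theta = [log sigma]                  (1 parameter)
    diag:      diag(sigma_1^2, ..., sigma_q^2), theta = log sigma_k  (q parameters)
    full:      L @ L.T, L lower triangular with exp'd diagonal   (q(q+1)/2 parameters)
    """
    theta = np.asarray(theta, dtype=float)
    if kind == "spherical":
        return np.exp(2.0 * theta[0]) * np.eye(dim)
    if kind == "diag":
        return np.diag(np.exp(2.0 * theta))
    if kind == "full":
        L = np.zeros((dim, dim))
        L[np.tril_indices(dim)] = theta
        L[np.diag_indices(dim)] = np.exp(np.diag(L))  # positive diagonal => positive definite
        return L @ L.T
    raise ValueError(f"unknown covariance structure: {kind!r}")


def n_free_params(dim, kind):
    """Number of free covariance parameters for each structure."""
    return {"spherical": 1, "diag": dim, "full": dim * (dim + 1) // 2}[kind]


Q = covariance_from_unconstrained(np.zeros(3), 3, "diag")  # identity, like torch.eye(3) above
precision = np.linalg.inv(Q)
print(n_free_params(3, "diag"), n_free_params(1, "spherical"))  # 3 1
```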

Step 3 — Prepare data

from jmstate.types import ModelData

data = ModelData(
    x,  # (n, p) covariate matrix
    t,  # (m,) or (n, m) measurement times; NaN-pad if variable
    y,  # (n, m, d) longitudinal observations; NaN-pad if variable
    trajectories,  # list[list[tuple[float, Any]]]
    c,  # (n, 1) right-censoring times
)

Each trajectory is a chronologically ordered list of (time, state) tuples representing the individual's event history.
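When individuals have different numbers of measurements, `t` and `y` need NaN padding to a common length. The sketch below builds NumPy versions of `t`, `y`, `trajectories` and `c` from hypothetical ragged records, and checks each trajectory for chronological order, admissible transitions, and consistency with the censoring time. It assumes the arrays are converted with `torch.as_tensor` before being passed to `ModelData`; the record layout and helper names are illustrative only.

```python
import numpy as np

# Hypothetical raw records: measurement times, biomarker values,
# (time, state) history starting in state 1, and end of follow-up.
records = [
    {"t": [0.5, 1.0, 2.0], "y": [0.8, 1.1, 0.9],
     "traj": [(0.0, 1), (1.7, 1), (3.2, 2)], "c": 5.0},
    {"t": [0.5, 1.5], "y": [0.6, 0.7],
     "traj": [(0.0, 1)], "c": 4.0},
]


def pad(rows):
    """Stack ragged 1-D sequences into an (n, m_max) array, NaN-padding short rows."""
    m = max(len(r) for r in rows)
    out = np.full((len(rows), m), np.nan)
    for i, r in enumerate(rows):
        out[i, : len(r)] = r
    return out


def check_trajectory(traj, edges, censor_time):
    """Ordered times, admissible consecutive states, last time within follow-up."""
    times = [s for s, _ in traj]
    states = [k for _, k in traj]
    ordered = all(a <= b for a, b in zip(times, times[1:]))
    admissible = all((a, b) in edges for a, b in zip(states, states[1:]))
    return bool(ordered and admissible and times[-1] <= censor_time)


t = pad([r["t"] for r in records])             # (n, m)
y = pad([r["y"] for r in records])[..., None]  # (n, m, d) with d = 1
trajectories = [r["traj"] for r in records]    # list[list[tuple[float, Any]]]
c = np.array([[r["c"]] for r in records])      # (n, 1)

E = {(1, 1), (1, 2)}
print([check_trajectory(tr, E, ci) for tr, ci in zip(trajectories, c[:, 0])])  # [True, True]
```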

Step 4 — Fit the model

import matplotlib.pyplot as plt
from jmstate import MultiStateJointModel

optimizer = torch.optim.Adam(params.parameters(), lr=0.1)
model = MultiStateJointModel(design, params, optimizer)

metrics = model.fit(data)

Step 5 — Print and plot the results

from jmstate.utils import plot_mcmc_diagnostics, plot_params_history

# Compute and print summary statistics (Wald statistics for H0: parameter = 0, p-values, AIC, BIC, etc.)
model.compute_summary().summary()

# Plot parameter history (stochastic optimization)...
plot_params_history(model)
plt.show()

# ...and MCMC sampler diagnostics
plot_mcmc_diagnostics(model)
plt.show()
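For reference, the quantities listed in the summary have standard textbook definitions, sketched below in plain Python. This is not jmstate's code: in particular, which sample size jmstate uses in BIC (individuals or observations) and how it obtains standard errors are not stated here. The Wald test shown is for the null hypothesis that a parameter equals zero.

```python
import math


def wald_test(estimate, std_error):
    """Wald test of H0: parameter = 0. Returns (z, two-sided p-value)."""
    z = estimate / std_error
    p_value = math.erfc(abs(z) / math.sqrt(2.0))  # equals 2 * (1 - Phi(|z|))
    return z, p_value


def aic(log_lik, n_params):
    """Akaike information criterion: 2k - 2 log L."""
    return 2.0 * n_params - 2.0 * log_lik


def bic(log_lik, n_params, n_obs):
    """Bayesian information criterion: k log n - 2 log L."""
    return n_params * math.log(n_obs) - 2.0 * log_lik


print(wald_test(0.8, 0.25))
print(aic(-512.3, 12), bic(-512.3, 12, 200))
```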

License

See LICENSE.
