
Joint modeling with automatic differentiation


jmstate

jmstate is a Python package for nonlinear multi-state joint modeling of longitudinal and time-to-event data. Built on PyTorch, it enables flexible specification of regression and link functions — including neural networks — while still offering built-in parametric baseline hazards and utilities for inference and prediction.

The package implements the methodology from:

A General Framework for Joint Multi-State Models. Félix Laplante and Christophe Ambroise (2025). arXiv:2510.07128.


Installation

```bash
pip install jmstate
```

Requirements: Python ≥ 3.10, PyTorch, scikit-learn, NumPy, Matplotlib, rich, tqdm.


Documentation

Full API reference and tutorials: jmstate documentation


The Model

jmstate fits a joint model that links a longitudinal biomarker process to multi-state event history through shared individual random effects.

Longitudinal sub-model

Individual observations follow

$$y_{ij} = h(t_{ij}, \psi_i) + \epsilon_{ij}, \qquad \epsilon_{ij} \sim \mathcal{N}(0, R)$$

where $h$ is a user-defined regression function (e.g. bi-exponential, logistic) and individual parameters are defined via

$$\psi_i = f(\gamma, X_i, b_i), \qquad b_i \sim \mathcal{N}(0, Q)$$

where $\gamma$ denotes the fixed population-level effects, $X_i$ the individual covariates, and $b_i$ the individual random effects.
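To make the two equations concrete, the sketch below simulates one individual's observations in plain Python. It assumes a diagonal $Q$, a scalar $R$ (one biomarker), the bi-exponential $h$, and the map $f(\gamma, X_i, b_i) = \gamma \exp(b_i)$ used in the Quick Start; the function names and numeric values are illustrative only and are not part of the jmstate API.

```python
import math
import random


def h(t: float, psi: tuple[float, ...]) -> float:
    # Regression function h(t, psi) = A * (exp(-k t) - exp(-ka t)), psi = (A, k, ka)
    A, k, ka = psi
    return A * (math.exp(-k * t) - math.exp(-ka * t))


def f(gamma: tuple[float, ...], x_i: object, b_i: list[float]) -> tuple[float, ...]:
    # psi_i = gamma * exp(b_i), component-wise; covariates x_i are ignored here
    return tuple(g * math.exp(b) for g, b in zip(gamma, b_i))


def simulate_individual(gamma, q_diag, r, times, rng):
    # b_i ~ N(0, Q) with Q = diag(q_diag)
    b_i = [rng.gauss(0.0, math.sqrt(q)) for q in q_diag]
    psi_i = f(gamma, None, b_i)
    # y_ij = h(t_ij, psi_i) + eps_ij with eps_ij ~ N(0, r) (scalar R, one biomarker)
    y_i = [h(t, psi_i) + rng.gauss(0.0, math.sqrt(r)) for t in times]
    return psi_i, y_i


rng = random.Random(0)
gamma = (4.0, 0.3, 1.5)  # hypothetical population values of (A, k, ka)
times = [0.5, 1.0, 2.0, 4.0, 8.0]
psi, y = simulate_individual(gamma, q_diag=(0.1, 0.1, 0.1), r=0.01, times=times, rng=rng)
```

Setting all variances to zero removes both sources of randomness, so the simulated curve collapses to $h(t, \gamma)$, which is a convenient sanity check.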

Multi-state sub-model

Let $G = (V, E)$ be a directed graph, where $V$ denotes the set of states and $E \subseteq V \times V$ the set of admissible transitions. The graph encodes all possible paths of the multi-state process, allowing for competing and recurrent transitions as well as absorbing states. The hazard of a transition $k \to k'$ at time $t$, for an individual who entered state $k$ at time $t_0$, is

$$\lambda^{k \to k'}(t_0, t) = \lambda_0^{k \to k'}(t_0, t) \exp\left( \alpha^{k \to k'} g^{k \to k'}(t, \psi_i) + \beta^{k \to k'} X_i \right),$$

where $\lambda_0^{k \to k'}$ are parametric baseline hazards, $g^{k \to k'}$ are link functions acting as a bridge between the longitudinal and the semi-Markov multi-state processes, and $\alpha^{k \to k'}$, $\beta^{k \to k'}$ are transition-specific coefficients.
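As a worked illustration of the hazard formula, the sketch below evaluates $\lambda^{k \to k'}(t_0, t)$ for one transition and the corresponding transition-specific survival over $(t_0, t]$. The Weibull baseline on the time since entry $t - t_0$, the choice of $g$ as the current biomarker value, and all parameter values are assumptions for illustration; they are not jmstate's built-in hazard classes.

```python
import math
from typing import Callable, Sequence


def weibull_baseline(t0: float, t: float, scale: float, shape: float) -> float:
    # Semi-Markov baseline: depends only on the time spent in the state, u = t - t0.
    # shape >= 1 keeps the hazard finite at u = 0.
    u = t - t0
    return (shape / scale) * (u / scale) ** (shape - 1.0)


def transition_hazard(
    t0: float,
    t: float,
    psi: Sequence[float],
    x: Sequence[float],
    alpha: float,
    beta: Sequence[float],
    link: Callable[[float, Sequence[float]], float],
    scale: float,
    shape: float,
) -> float:
    # lambda(t0, t) = lambda0(t0, t) * exp(alpha * g(t, psi) + beta . x)
    eta = alpha * link(t, psi) + sum(bj * xj for bj, xj in zip(beta, x))
    return weibull_baseline(t0, t, scale, shape) * math.exp(eta)


def transition_survival(t0: float, t: float, n_steps: int = 200, **kw) -> float:
    # exp(-cumulative hazard of this single transition over (t0, t]), trapezoid rule;
    # with several outgoing transitions, the sojourn survival is their product
    dt = (t - t0) / n_steps
    vals = [transition_hazard(t0, t0 + i * dt, **kw) for i in range(n_steps + 1)]
    cum = dt * (sum(vals) - 0.5 * (vals[0] + vals[-1]))
    return math.exp(-cum)


def biomarker(t: float, psi: Sequence[float]) -> float:
    # Hypothetical link g: current value of the bi-exponential biomarker
    A, k, ka = psi
    return A * (math.exp(-k * t) - math.exp(-ka * t))


common = dict(psi=(4.0, 0.3, 1.5), x=(1.0,), beta=(0.0,), link=biomarker, scale=2.0, shape=1.0)
print(transition_hazard(1.0, 3.0, alpha=0.5, **common))
print(transition_survival(1.0, 3.0, alpha=0.0, **common))  # constant hazard 0.5 over 2 time units
```

With `shape=1.0` and `alpha=0.0` the hazard reduces to a constant $1/\text{scale}$, so the survival over two time units should be $e^{-1}$.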

The model supports arbitrary state graphs (recurrent, absorbing, monotone, etc.) under a semi-Markov assumption.
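One simple way to encode the state graph and check an observed path against it is an edge set of `(from, to)` tuples, the same shape as the keys of `link_fns` in the Quick Start. The illness-death graph and the helper functions below are hypothetical; a recurrent transition would be a self-loop such as `(1, 1)`.

```python
from typing import Any, Hashable

# Hypothetical illness-death graph G = (V, E): 0 = healthy, 1 = ill, 2 = dead
STATES = {0, 1, 2}
EDGES = {(0, 1), (0, 2), (1, 2)}
assert all(src in STATES and dst in STATES for src, dst in EDGES)  # E is a subset of V x V


def outgoing(state: Hashable, edges: set) -> set:
    # States reachable in one transition from `state`
    return {dst for (src, dst) in edges if src == state}


def is_absorbing(state: Hashable, edges: set) -> bool:
    # A state with no outgoing edge can never be left
    return not outgoing(state, edges)


def is_valid_trajectory(traj: list[tuple[float, Any]], edges: set) -> bool:
    # Times must strictly increase and every jump must be an admissible edge
    for (t_prev, s_prev), (t_next, s_next) in zip(traj, traj[1:]):
        if t_next <= t_prev or (s_prev, s_next) not in edges:
            return False
    return True


def sojourn_times(traj: list[tuple[float, Any]]) -> list[float]:
    # Semi-Markov clock: time spent in each state before the next jump
    return [t_next - t_prev for (t_prev, _), (t_next, _) in zip(traj, traj[1:])]
```

Under the semi-Markov assumption, the baseline hazards are evaluated on these sojourn times rather than on calendar time.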

Estimation

Parameters are estimated by maximising the observed-data log-likelihood, whose gradient is given by the Fisher identity

$$\nabla_\theta \log \mathcal{L}(\theta; x) = \mathbb{E}_{b \sim p(\cdot \mid x, \theta)}\left[ \nabla_\theta \log \mathcal{L}(\theta; x, b) \right],$$

where $\mathcal{L}(\theta; x, b)$ is the complete likelihood of the data given the parameters and random effects.
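The identity can be checked numerically on a toy model where everything is available in closed form: $b \sim \mathcal{N}(0, 1)$ and $y \mid b \sim \mathcal{N}(\theta + b, 1)$, so that $y \sim \mathcal{N}(\theta, 2)$ and $b \mid y \sim \mathcal{N}((y - \theta)/2, 1/2)$. The model and names are assumptions chosen for the check, not jmstate's likelihood; averaging the complete-data score over posterior draws should recover the observed-data score $(y - \theta)/2$.

```python
import math
import random


def observed_score(theta: float, y: float) -> float:
    # d/dtheta log N(y; theta, 2) = (y - theta) / 2
    return (y - theta) / 2.0


def complete_score(theta: float, y: float, b: float) -> float:
    # d/dtheta [log N(y; theta + b, 1) + log N(b; 0, 1)] = y - theta - b
    return y - theta - b


def fisher_identity_estimate(theta: float, y: float, n_samples: int, rng: random.Random) -> float:
    # Monte Carlo average of the complete-data score over exact posterior draws
    # b | y, theta ~ N((y - theta) / 2, 1 / 2)
    mean, sd = (y - theta) / 2.0, math.sqrt(0.5)
    total = sum(complete_score(theta, y, rng.gauss(mean, sd)) for _ in range(n_samples))
    return total / n_samples


rng = random.Random(1)
print(observed_score(0.3, 2.1), fisher_identity_estimate(0.3, 2.1, 100_000, rng))
```

In jmstate's setting the posterior is not available in closed form, which is why the expectation has to be approximated by MCMC.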

The expectation on the right-hand side is approximated with a Metropolis-within-Gibbs MCMC sampler over the random effects, and the resulting gradient estimate drives a stochastic gradient ascent step. Convergence is monitored with an $R^2$-based stationarity test.
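A minimal sketch of this scheme on a Gaussian toy model (an assumption: $y_i \mid b_i \sim \mathcal{N}(\theta + b_i, 1)$ and $b_i \sim \mathcal{N}(0, 1)$, whose maximum-likelihood estimate of $\theta$ is the sample mean of the $y_i$). Each iteration runs one random-walk Metropolis update per random effect, then takes a gradient step that uses the current MCMC state as a one-draw Monte Carlo estimate of the Fisher identity. `trend_r2` is a hypothetical reading of an $R^2$-based stationarity check (the $R^2$ of a linear trend fitted to the recent parameter trace); jmstate's actual test may differ.

```python
import math
import random


def log_post(b: float, y: float, theta: float) -> float:
    # log p(b | y, theta) up to a constant: N(0, 1) prior + N(theta + b, 1) likelihood
    return -0.5 * b * b - 0.5 * (y - theta - b) ** 2


def mwg_sweep(bs: list[float], ys: list[float], theta: float, step: float, rng: random.Random) -> None:
    # One Metropolis-within-Gibbs sweep: a random-walk Metropolis update of each b_i
    # in turn, conditional on the current theta and the other coordinates
    for i, y in enumerate(ys):
        prop = bs[i] + step * rng.gauss(0.0, 1.0)
        log_ratio = log_post(prop, y, theta) - log_post(bs[i], y, theta)
        if log_ratio >= 0.0 or rng.random() < math.exp(log_ratio):
            bs[i] = prop


def trend_r2(values: list[float]) -> float:
    # Hypothetical stationarity check: R^2 of a least-squares linear trend on a window
    n = len(values)
    x_mean, v_mean = (n - 1) / 2.0, sum(values) / n
    sxx = sum((i - x_mean) ** 2 for i in range(n))
    sxy = sum((i - x_mean) * (v - v_mean) for i, v in enumerate(values))
    syy = sum((v - v_mean) ** 2 for v in values)
    return 0.0 if syy == 0.0 else sxy * sxy / (sxx * syy)


def fit(ys: list[float], n_iter: int = 1000, lr: float = 0.2, step: float = 1.0, seed: int = 0):
    rng = random.Random(seed)
    theta, bs, trace = 0.0, [0.0] * len(ys), []
    for _ in range(n_iter):
        mwg_sweep(bs, ys, theta, step, rng)
        # Fisher identity with the current MCMC state as a one-draw estimate
        grad = sum(y - theta - b for y, b in zip(ys, bs)) / len(ys)
        theta += lr * grad  # stochastic gradient ascent step
        trace.append(theta)
    half = trace[n_iter // 2:]
    return sum(half) / len(half), trace  # average of the second half of the iterates


data_rng = random.Random(42)
ys = [2.0 + data_rng.gauss(0.0, 1.0) + data_rng.gauss(0.0, 1.0) for _ in range(200)]
theta_hat, trace = fit(ys)
print(theta_hat, sum(ys) / len(ys), trend_r2(trace[-200:]))
```

With a constant step size the raw iterates keep fluctuating around the estimate, so averaging the second half of the trace is used here to smooth out the MCMC noise.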


Quick Start

Step 1 — Define the model design

```python
import torch
from jmstate.types import ModelDesign


# Individual parameters: psi_i = fixed * exp(b_i) (covariates x are unused here)
def indiv_effects_fn(
    fixed: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    return fixed * torch.exp(b)  # (..., n, q)


# PK function: bi-exponential biomarker A * (exp(-k t) - exp(-ka t))
def pk_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    conc = A * (torch.exp(-k * t) - torch.exp(-ka * t))
    return conc.unsqueeze(-1)


# PK integral function: bi-exponential cumulative link, i.e. the integral of
# pk_fn from 0 to t (A multiplies both terms)
def pk_integral_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    integral = A * ((1 - torch.exp(-k * t)) / k - (1 - torch.exp(-ka * t)) / ka)
    return integral.unsqueeze(-1)


# Define the model design
design = ModelDesign(
    indiv_effects_fn,
    regression_fn=pk_fn,
    link_fns={
        (1, 1): pk_integral_fn,
        (1, 2): pk_integral_fn,
    },
)
```
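The cumulative link is meant to be the area under the bi-exponential curve, $\int_0^t A(e^{-ks} - e^{-k_a s})\,ds = A\left[\frac{1 - e^{-kt}}{k} - \frac{1 - e^{-k_a t}}{k_a}\right]$ for $k, k_a > 0$, so $A$ must multiply both terms. A framework-free sanity check compares this closed form with a trapezoidal approximation; the helper names and parameter values are arbitrary illustrations.

```python
import math


def bateman(t: float, A: float, k: float, ka: float) -> float:
    # Bi-exponential curve A * (exp(-k t) - exp(-ka t))
    return A * (math.exp(-k * t) - math.exp(-ka * t))


def bateman_integral(t: float, A: float, k: float, ka: float) -> float:
    # Closed-form integral from 0 to t; A multiplies both terms
    return A * ((1 - math.exp(-k * t)) / k - (1 - math.exp(-ka * t)) / ka)


def trapezoid(fn, a: float, b: float, n: int = 10_000) -> float:
    # Composite trapezoidal rule with n sub-intervals
    h = (b - a) / n
    return h * (0.5 * fn(a) + sum(fn(a + i * h) for i in range(1, n)) + 0.5 * fn(b))


A, k, ka, t = 4.0, 0.3, 1.5, 6.0
exact = bateman_integral(t, A, k, ka)
approx = trapezoid(lambda s: bateman(s, A, k, ka), 0.0, t)
print(exact, approx)
```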

Step 2 — Set initial parameters

```python
from jmstate.functions.base_hazards import Exponential
from jmstate.types import ModelParameters, PrecisionParameters

# Define simple initial parameters
params = ModelParameters(
    torch.ones(3),
    PrecisionParameters.from_covariance(torch.eye(3), "diag"),
    PrecisionParameters.from_covariance(torch.eye(1), "spherical"),
    {(1, 1): Exponential(1.0), (1, 2): Exponential(1.0)},
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},
)
```
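The strings `"diag"` and `"spherical"` select a covariance structure. How jmstate parametrizes them internally is not described here, so the sketch below is an assumption: it counts the free parameters of each structure and rebuilds a covariance from unconstrained log-precisions, which is one common way to keep variances positive during gradient ascent. In the Quick Start, the 3×3 `"diag"` matrix matches the three random effects and the 1×1 `"spherical"` matrix matches a single biomarker.

```python
import math


def n_free_params(kind: str, dim: int) -> int:
    # Free parameters of a dim x dim covariance under each (assumed) structure
    if kind == "spherical":
        return 1                      # Sigma = s^2 * I
    if kind == "diag":
        return dim                    # Sigma = diag(s_1^2, ..., s_dim^2)
    if kind == "full":
        return dim * (dim + 1) // 2   # any symmetric positive-definite matrix
    raise ValueError(f"unknown covariance kind: {kind!r}")


def covariance_from_log_precision(kind: str, raw: list[float], dim: int) -> list[list[float]]:
    # Hypothetical parametrization: raw holds log-precisions, so every real value
    # maps to a positive variance exp(-raw); only the diagonal structures are shown
    if kind == "spherical":
        variances = [math.exp(-raw[0])] * dim
    elif kind == "diag":
        assert len(raw) == dim, "one log-precision per dimension"
        variances = [math.exp(-r) for r in raw]
    else:
        raise NotImplementedError(f"{kind!r} is not covered by this sketch")
    return [[variances[i] if i == j else 0.0 for j in range(dim)] for i in range(dim)]


print(n_free_params("diag", 3), n_free_params("spherical", 1))
print(covariance_from_log_precision("diag", [0.0, 0.0, math.log(2.0)], 3))
```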

Step 3 — Prepare data

```python
from jmstate.types import ModelData

data = ModelData(
    x,  # (n, p) covariate matrix
    t,  # (m,) or (n, m) measurement times; NaN-pad if variable
    y,  # (n, m, d) longitudinal observations; NaN-pad if variable
    trajectories,  # list[list[tuple[float, Any]]]
    c,  # (n, 1) right-censoring times
)
```

Each trajectory is a chronologically ordered list of (time, state) tuples representing the individual's event history.
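When individuals have different visit schedules, the measurement times and observations can be NaN-padded into rectangular nested lists before being converted with `torch.tensor`. The sketch below does this in plain Python for two hypothetical individuals and builds matching `trajectories` and `c`; the data, the state labels, and the `nan_pad` helper are illustrative assumptions, with states `1` and `2` chosen to match the `link_fns` keys above.

```python
import math
from typing import Any


def nan_pad(rows: list[list[float]]) -> list[list[float]]:
    # Right-pad each row with NaN up to the longest row length m -> (n, m)
    m = max(len(row) for row in rows)
    return [list(row) + [math.nan] * (m - len(row)) for row in rows]


# Hypothetical raw data: two individuals with different visit schedules
visit_times = [[0.0, 1.0, 2.0], [0.0, 1.5]]
biomarker = [[0.1, 0.8, 0.6], [0.2, 0.9]]

t = nan_pad(visit_times)                                # (n, m) = (2, 3)
y = [[[v] for v in row] for row in nan_pad(biomarker)]  # (n, m, d) with d = 1

# Chronologically ordered (time, state) pairs per individual
trajectories: list[list[tuple[float, Any]]] = [
    [(0.0, 1), (1.7, 1), (2.5, 2)],  # a recurrent 1 -> 1 event, then 1 -> 2
    [(0.0, 1)],                      # no observed transition before censoring
]
c = [[3.0], [2.0]]  # (n, 1) right-censoring times

# Basic consistency checks before building ModelData
assert all(a[0] < b[0] for traj in trajectories for a, b in zip(traj, traj[1:]))
assert all(ci[0] >= traj[-1][0] for ci, traj in zip(c, trajectories))
```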

Step 4 — Fit the model

```python
import matplotlib.pyplot as plt
from jmstate import MultiStateJointModel

optimizer = torch.optim.Adam(params.parameters(), lr=0.1)
model = MultiStateJointModel(design, params, optimizer)

metrics = model.fit(data)
```

Step 5 — Print and plot the results

```python
from jmstate.utils import plot_mcmc_diagnostics, plot_params_history

# Compute and print summary statistics (Wald tests of nullity, p-values, AIC, BIC, etc.)
model.compute_summary().summary()

# Plot parameter history (stochastic optimization)...
plot_params_history(model)
plt.show()

# ...and MCMC sampler diagnostics
plot_mcmc_diagnostics(model)
plt.show()
```
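For reference, the usual textbook definitions behind these summaries are sketched below: a Wald statistic $z = \hat\theta / \widehat{\mathrm{se}}(\hat\theta)$ for the null hypothesis $\theta = 0$ with a two-sided normal p-value, $\mathrm{AIC} = 2K - 2\log\hat{\mathcal{L}}$ and $\mathrm{BIC} = K \log n - 2\log\hat{\mathcal{L}}$, where $K$ is the number of estimated parameters. Whether jmstate uses exactly these conventions (for example, which sample size $n$ enters the BIC, or how the observed-data log-likelihood is estimated) is an assumption not confirmed here.

```python
import math


def wald_test(estimate: float, std_error: float) -> tuple[float, float]:
    # Wald statistic for H0: parameter = 0, with a two-sided normal p-value
    # p = 2 * (1 - Phi(|z|)) = erfc(|z| / sqrt(2)); the chi-square form is z ** 2
    z = estimate / std_error
    p_value = math.erfc(abs(z) / math.sqrt(2.0))
    return z, p_value


def aic(log_lik: float, n_params: int) -> float:
    # Akaike information criterion: 2K - 2 log L
    return 2.0 * n_params - 2.0 * log_lik


def bic(log_lik: float, n_params: int, n_obs: int) -> float:
    # Bayesian information criterion: K log n - 2 log L
    return n_params * math.log(n_obs) - 2.0 * log_lik


print(wald_test(0.8, 0.3), aic(-120.5, 10), bic(-120.5, 10, 200))
```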

License

See LICENSE.
