Joint modeling with automatic differentiation
jmstate
jmstate is a Python package for nonlinear multi-state joint modeling of longitudinal and time-to-event data. Built on PyTorch, it enables flexible specification of regression and link functions — including neural networks — while still offering built-in parametric baseline hazards and utilities for inference and prediction.
The package implements the methodology from:
A General Framework for Joint Multi-State Models. Félix Laplante & Christophe Ambroise (2025), arXiv:2510.07128
Installation
pip install jmstate
Requirements: Python ≥ 3.10, PyTorch, scikit-learn, NumPy, Matplotlib, rich, tqdm.
Documentation
Full API reference and tutorials: jmstate documentation
The Model
jmstate fits a joint model that links a longitudinal biomarker process to multi-state event history through shared individual random effects.
Longitudinal sub-model
Individual observations follow
$$y_{ij} = h(t_{ij}, \psi_i) + \epsilon_{ij}, \qquad \epsilon_{ij} \sim \mathcal{N}(0, R)$$
where $h$ is a user-defined regression function (e.g. bi-exponential, logistic) and individual parameters are defined via
$$\psi_i = f(\gamma, X_i, b_i), \qquad b_i \sim \mathcal{N}(0, Q)$$
with $\gamma$ fixed population-level effects, $X_i$ covariates, and $b_i$ individual random effects.
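The sketch below shows how the two equations fit together by simulating a small dataset from this sub-model in plain PyTorch. It reuses the log-normal $f$ ($\psi_i = \gamma \odot e^{b_i}$) and the bi-exponential $h$ from the Quick Start. It does not call jmstate. The sample sizes, parameter values and diagonal $Q$ are illustrative assumptions, and this particular $f$ ignores the covariates.

```python
import torch

torch.manual_seed(0)

n, m, q = 4, 6, 3  # individuals, measurements per individual, random-effect dimension

gamma = torch.tensor([2.0, 0.3, 1.5])  # population effects (A, k, ka), illustrative values
Q = 0.1 * torch.eye(q)                 # random-effects covariance (assumed diagonal)
R = torch.tensor([[0.01]])             # residual covariance for a scalar biomarker (d = 1)


def f(gamma: torch.Tensor, x: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """psi_i = gamma * exp(b_i): log-normal individual parameters (x is ignored here)."""
    return gamma * torch.exp(b)


def h(t: torch.Tensor, psi: torch.Tensor) -> torch.Tensor:
    """Bi-exponential regression function; returns shape (n, m, 1)."""
    A, k, ka = psi.chunk(3, dim=-1)  # each (n, 1), broadcast against t of shape (m,)
    return (A * (torch.exp(-k * t) - torch.exp(-ka * t))).unsqueeze(-1)


x = torch.zeros(n, 1)              # placeholder covariates
t = torch.linspace(0.5, 12.0, m)   # shared measurement grid, shape (m,)

b = torch.distributions.MultivariateNormal(torch.zeros(q), Q).sample((n,))        # (n, q)
psi = f(gamma, x, b)                                                              # (n, q)
eps = torch.distributions.MultivariateNormal(torch.zeros(1), R).sample((n, m))    # (n, m, 1)
y = h(t, psi) + eps                                                               # (n, m, 1)
```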
Multi-state sub-model
Let $G = (V, E)$ be a directed graph, where $V$ denotes the set of states and $E \subseteq V \times V$ the set of admissible transitions. The graph encodes all possible paths of the multi-state process, allowing for competing, recurrent, or absorbing transitions. The hazard for a transition $k \to k'$ at time $t$ given entry time $t_0$ satisfies
$$\lambda^{k \to k'}(t_0, t) = \lambda_0^{k \to k'}(t_0, t) \exp\left( \alpha^{k \to k'} g^{k \to k'}(t, \psi_i) + \beta^{k \to k'} X_i \right),$$
where $\lambda_0^{k \to k'}$ are parametric baseline hazards, $g^{k \to k'}$ are link functions acting as a bridge between the longitudinal and the semi-Markov multi-state processes, and $\alpha^{k \to k'}$, $\beta^{k \to k'}$ are transition-specific coefficients.
The model supports arbitrary state graphs (recurrent, absorbing, monotone, etc.) under a semi-Markov assumption.
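The sketch below evaluates the hazard formula for one individual on a hypothetical illness-death graph. It then converts the result into the probability of remaining in a state, $\exp\left(-\sum_{k'} \int_{t_0}^{t} \lambda^{k \to k'}(t_0, u)\,du\right)$. Several choices are assumptions made for the example: the Weibull baseline in the sojourn time $t - t_0$, the current-value link, and all coefficient values. The Quick Start below instead uses an exponential baseline and a cumulative link. This is not jmstate's internal code.

```python
import torch

# Hypothetical illness-death graph G = (V, E): 1 = healthy, 2 = ill, 3 = dead (absorbing)
V = {1, 2, 3}
E = {(1, 2), (1, 3), (2, 3)}
assert all(a in V and b in V for a, b in E)


def weibull_base(t0, t, shape=1.5, scale=10.0):
    """Weibull baseline hazard in the sojourn time t - t0 (semi-Markov clock reset)."""
    s = (t - t0).clamp_min(1e-8)
    return (shape / scale) * (s / scale) ** (shape - 1)


def link(t, psi):
    """g(t, psi): current biomarker value, bi-exponential with psi = (A, k, ka)."""
    A, k, ka = psi.unbind(-1)
    return A * (torch.exp(-k * t) - torch.exp(-ka * t))


def hazard(t0, t, psi, x, alpha, beta):
    """lambda(t0, t) = lambda_0(t0, t) * exp(alpha * g(t, psi) + beta . x)."""
    return weibull_base(t0, t) * torch.exp(alpha * link(t, psi) + x @ beta)


psi = torch.tensor([2.0, 0.3, 1.5])  # one individual's (A, k, ka)
x = torch.tensor([1.0])              # one covariate
coefs = {                            # illustrative (alpha, beta) per transition
    (1, 2): (0.5, torch.tensor([0.2])),
    (1, 3): (0.1, torch.tensor([0.0])),
    (2, 3): (0.8, torch.tensor([0.3])),
}
assert set(coefs) == E


def stay_probability(state, t0, t_end, n_grid=2001):
    """P(no transition out of `state` during (t0, t_end]) = exp(-sum of cumulative hazards)."""
    grid = torch.linspace(t0, t_end, n_grid)
    total = torch.zeros(n_grid)
    for (src, _dst), (alpha, beta) in coefs.items():
        if src == state:  # only edges leaving `state` compete
            total = total + hazard(t0, grid, psi, x, alpha, beta)
    return torch.exp(-torch.trapezoid(total, grid))


p_stay = stay_probability(1, 0.0, 5.0)
p_absorbing = stay_probability(3, 2.0, 5.0)  # state 3 has no outgoing edges
```

An absorbing state has no outgoing edges in $E$, so its stay probability is exactly 1. This is how the graph alone encodes absorption.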
Estimation
Parameters are estimated by maximising the observed-data log-likelihood using the Fisher identity
$$\nabla_\theta \log \mathcal{L}(\theta; x) = \mathbb{E}_{b \sim p(\cdot \mid x, \theta)}\left[ \nabla_\theta \log \mathcal{L}(\theta; x, b) \right],$$
where $\mathcal{L}(\theta; x, b)$ is the complete likelihood of the data given the parameters and random effects.
This expectation is approximated by evaluating the complete-data gradient at random-effect samples drawn with a Metropolis-within-Gibbs MCMC sampler. The resulting gradient estimate drives a stochastic gradient ascent step on $\theta$. Convergence is monitored via an $R^2$-based stationarity test.
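Here is a self-contained toy version of this scheme in plain PyTorch. It uses a Gaussian random-intercept model, chosen as an assumption because, with a balanced design, its maximum-likelihood estimate of $\mu$ is simply the grand mean of $y$. The Gibbs blocks are random-walk Metropolis updates, one per individual, and each iteration takes an Adam step on the complete-data gradient. This is not jmstate's sampler. Step sizes, iteration counts and the averaging at the end are illustrative choices.

```python
import torch

torch.manual_seed(0)

# Toy data: y_ij = mu + b_i + eps_ij, b_i ~ N(0, tau^2), eps_ij ~ N(0, sigma^2)
n, m = 100, 4
sigma, tau, true_mu = 1.0, 0.5, 2.0
y = true_mu + tau * torch.randn(n, 1) + sigma * torch.randn(n, m)

mu = torch.zeros(1, requires_grad=True)  # theta: the only parameter estimated here
optimizer = torch.optim.Adam([mu], lr=0.05)


def log_joint(b: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
    """Per-individual complete-data log-likelihood log p(y_i, b_i | mu), up to constants."""
    resid = y - mu - b  # (n, m)
    return -0.5 * (resid**2).sum(-1) / sigma**2 - 0.5 * b.squeeze(-1) ** 2 / tau**2


b = torch.zeros(n, 1)  # current state of the random-effects chain
proposal_sd = 0.4
mu_trace = []
for it in range(1000):
    # Metropolis-within-Gibbs: each b_i is its own block; blocks are conditionally
    # independent given mu, so all n proposals are accepted/rejected in parallel
    with torch.no_grad():
        for _ in range(5):
            prop = b + proposal_sd * torch.randn_like(b)
            log_ratio = log_joint(prop, mu) - log_joint(b, mu)
            accept = torch.rand(n).log() < log_ratio
            b = torch.where(accept.unsqueeze(-1), prop, b)

    # Fisher identity with a single posterior draw: the complete-data gradient at the
    # current b is a noisy estimate of the observed-data gradient; take one ascent step
    optimizer.zero_grad()
    (-log_joint(b, mu).sum()).backward()
    optimizer.step()
    mu_trace.append(mu.item())

# Average the second half of the trajectory to damp the stochastic-gradient noise
mu_hat = sum(mu_trace[500:]) / len(mu_trace[500:])
```

Because the design is balanced, `mu_hat` should land close to `y.mean()`, the exact observed-data MLE. Plotting `mu_trace` shows the kind of noisy-but-stationary trajectory that a stationarity test is meant to detect.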
Quick Start
Step 1 — Define the model design
import torch
from jmstate.types import ModelDesign
# Individual parameters: psi_i = fixed * exp(b_i), log-normal around the fixed effects
def indiv_effects_fn(
    fixed: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    return fixed * torch.exp(b)  # (..., n, q)

# PK function: bi-exponential biomarker (the optional D argument is unused here)
def pk_fn(t: torch.Tensor, indiv_params: torch.Tensor, D: float = 1.0):
    A, k, ka = indiv_params.chunk(3, dim=-1)
    conc = A * (torch.exp(-k * t) - torch.exp(-ka * t))
    return conc.unsqueeze(-1)

# PK integral function: bi-exponential cumulative link, the integral of pk_fn from 0 to t
def pk_integral_fn(t: torch.Tensor, indiv_params: torch.Tensor):
    A, k, ka = indiv_params.chunk(3, dim=-1)
    # A multiplies both terms of the antiderivative
    integral = A * ((1 - torch.exp(-k * t)) / k - (1 - torch.exp(-ka * t)) / ka)
    return integral.unsqueeze(-1)

# Define the model design
design = ModelDesign(
    indiv_effects_fn,
    regression_fn=pk_fn,
    link_fns={
        (1, 1): pk_integral_fn,
        (1, 2): pk_integral_fn,
    },
)
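Because `pk_integral_fn` is meant to be the integral of `pk_fn` from 0 to $t$, it is worth checking the two against each other numerically. The standalone sketch below redefines both in plain PyTorch, dropping the unused `D` argument, and compares the closed form with `torch.cumulative_trapezoid`. The parameter values are arbitrary.

```python
import torch


def pk_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    return (A * (torch.exp(-k * t) - torch.exp(-ka * t))).unsqueeze(-1)


def pk_integral_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    integral = A * ((1 - torch.exp(-k * t)) / k - (1 - torch.exp(-ka * t)) / ka)
    return integral.unsqueeze(-1)


# Two individuals with arbitrary (A, k, ka); float64 keeps quadrature error negligible
psi = torch.tensor([[2.0, 0.3, 1.5], [1.0, 0.1, 0.8]], dtype=torch.float64)
t = torch.linspace(0.0, 10.0, 4001, dtype=torch.float64)

# cumulative_trapezoid returns the running integral at t[1:], so drop t[0] from the closed form
numeric = torch.cumulative_trapezoid(pk_fn(t, psi).squeeze(-1), t, dim=-1)
closed_form = pk_integral_fn(t, psi).squeeze(-1)[:, 1:]
max_err = (numeric - closed_form).abs().max().item()
```

A small `max_err` confirms that the cumulative link is consistent with the regression function. Applying `A` to only the first term of the antiderivative would make this check fail.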
Step 2 — Set initial parameters
from jmstate.functions.base_hazards import Exponential
from jmstate.types import ModelParameters, PrecisionParameters
# Define simple initial parameters
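params = ModelParameters(
    torch.ones(3),  # fixed population effects (A, k, ka)
    PrecisionParameters.from_covariance(torch.eye(3), "diag"),  # random effects b_i: 3 x 3 diagonal Q
    PrecisionParameters.from_covariance(torch.eye(1), "spherical"),  # residual noise: 1 x 1 R
    {(1, 1): Exponential(1.0), (1, 2): Exponential(1.0)},  # baseline hazards per transition
    # per-transition coefficient tensors (alpha and beta in the hazard)
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},
)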
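The `"diag"` and `"spherical"` arguments name the covariance structure. One common way to make such structured covariances trainable by an unconstrained optimizer such as Adam is to store log-variances. The sketch below shows that general technique in plain PyTorch. It is an assumption, not a description of how `PrecisionParameters` stores its values; the class name suggests it works with precision matrices, which this sketch does not reproduce. The helper names are hypothetical.

```python
import torch


def covariance_from_raw(raw: torch.Tensor, dim: int, kind: str) -> torch.Tensor:
    """Map unconstrained parameters to a structured covariance matrix (illustrative)."""
    if kind == "diag":
        # raw has shape (dim,): one log-variance per coordinate
        return torch.diag(torch.exp(raw))
    if kind == "spherical":
        # raw has shape (1,): one shared log-variance, giving sigma^2 * I
        return torch.exp(raw) * torch.eye(dim)
    raise ValueError(f"unsupported covariance kind: {kind!r}")


def raw_from_covariance(cov: torch.Tensor, kind: str) -> torch.Tensor:
    """Inverse map, assuming `cov` already has the requested structure."""
    log_var = torch.log(torch.diagonal(cov))
    return log_var if kind == "diag" else log_var.mean().reshape(1)


# Leaf tensors an optimizer could update freely (no positivity constraint needed)
raw_Q = raw_from_covariance(torch.eye(3), "diag").requires_grad_()
raw_R = raw_from_covariance(torch.eye(1), "spherical").requires_grad_()
Q = covariance_from_raw(raw_Q, 3, "diag")       # 3 x 3 random-effects covariance
R = covariance_from_raw(raw_R, 1, "spherical")  # 1 x 1 residual covariance
```

Because `exp` is always positive, any real-valued `raw` yields a valid covariance of the chosen structure. This is why the parameterisation pairs naturally with the `torch.optim.Adam` optimizer used in Step 4.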
Step 3 — Prepare data
from jmstate.types import ModelData
data = ModelData(
    x,             # (n, p) covariate matrix
    t,             # (m,) or (n, m) measurement times; NaN-pad if variable
    y,             # (n, m, d) longitudinal observations; NaN-pad if variable
    trajectories,  # list[list[tuple[float, Any]]]
    c,             # (n, 1) right-censoring times
)
Each trajectory is a chronologically ordered list of (time, state) tuples representing the individual's event history.
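The sketch below shows one way to assemble inputs with the shapes listed above from ragged per-individual records, NaN-padding the measurement grid. It also defines a hypothetical `check_trajectory` helper that tests each trajectory against a set of allowed transitions. Three rules are assumptions for illustration, not documented jmstate requirements: each trajectory starts with its initial state at time 0, times strictly increase, and events precede censoring. The allowed set reuses the transitions `(1, 1)` and `(1, 2)` from Step 1.

```python
import torch

# Hypothetical ragged records: per-individual measurement times and biomarker values
times = [[0.5, 1.0, 2.0, 4.0], [0.5, 2.0], [1.0, 3.0, 6.0]]
values = [[0.8, 1.1, 0.9, 0.5], [0.7, 0.6], [1.2, 0.9, 0.4]]

n, m = len(times), max(len(ts) for ts in times)
t = torch.full((n, m), float("nan"))     # (n, m), NaN where an individual has no measurement
y = torch.full((n, m, 1), float("nan"))  # (n, m, d) with d = 1
for i, (ts, ys) in enumerate(zip(times, values)):
    t[i, : len(ts)] = torch.tensor(ts)
    y[i, : len(ys), 0] = torch.tensor(ys)

x = torch.tensor([[0.0], [1.0], [1.0]])   # (n, p) with p = 1
c = torch.tensor([[8.0], [5.0], [10.0]])  # (n, 1) right-censoring times

# (time, state) tuples, assumed to start with the initial state at time 0
trajectories = [
    [(0.0, 1), (3.0, 1), (7.0, 2)],
    [(0.0, 1)],
    [(0.0, 1), (2.5, 2)],
]


def check_trajectory(traj, allowed, censor_time):
    """Return a list of problems found in one trajectory (empty if none)."""
    problems = []
    for (t_prev, s_prev), (t_next, s_next) in zip(traj, traj[1:]):
        if not t_next > t_prev:
            problems.append(f"times not increasing at {t_next}")
        if (s_prev, s_next) not in allowed:
            problems.append(f"transition {s_prev}->{s_next} not in E")
    if traj and traj[-1][0] > censor_time:
        problems.append("event after censoring time")
    return problems


E = {(1, 1), (1, 2)}
report = [check_trajectory(tr, E, float(ci)) for tr, ci in zip(trajectories, c[:, 0])]
```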
Step 4 — Fit the model
import matplotlib.pyplot as plt
from jmstate import MultiStateJointModel
optimizer = torch.optim.Adam(params.parameters(), lr=0.1)
model = MultiStateJointModel(design, params, optimizer)
metrics = model.fit(data)
Step 5 — Print and plot the results
from jmstate.utils import plot_mcmc_diagnostics, plot_params_history
# Compute and print summary statistics (Wald tests of nullity, p-values, AIC, BIC, etc.)
model.compute_summary().summary()
# Plot parameter history (stochastic optimization)...
plot_params_history(model)
plt.show()
# ...and MCMC sampler diagnostics
plot_mcmc_diagnostics(model)
plt.show()
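For reference, the textbook definitions behind such a summary fit in a few lines of standard-library Python. The sketch assumes the usual forms: the Wald statistic $z = \hat\theta / \widehat{\mathrm{SE}}$ with a two-sided normal p-value, $\mathrm{AIC} = 2k - 2\log\mathcal{L}$, and $\mathrm{BIC} = k\log n - 2\log\mathcal{L}$. It does not cover how jmstate estimates standard errors or the observed-data log-likelihood, nor which $n$ it uses in the BIC.

```python
import math


def wald_nullity(estimate: float, std_error: float) -> tuple[float, float]:
    """Wald test of H0: parameter = 0. Returns (z statistic, two-sided p-value)."""
    z = estimate / std_error
    p_value = math.erfc(abs(z) / math.sqrt(2.0))  # equals 2 * (1 - Phi(|z|))
    return z, p_value


def aic(log_lik: float, n_params: int) -> float:
    """Akaike information criterion: 2k - 2 log L."""
    return 2 * n_params - 2 * log_lik


def bic(log_lik: float, n_params: int, n_obs: int) -> float:
    """Bayesian information criterion: k log n - 2 log L."""
    return n_params * math.log(n_obs) - 2 * log_lik
```

For example, `wald_nullity(0.5, 0.25)` gives $z = 2$ and a p-value of about 0.0455. `aic(-100.0, 5)` is 210.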
License
See LICENSE.
Download files
File details
Details for the file jmstate-0.17.2.tar.gz.
File metadata
- Download URL: jmstate-0.17.2.tar.gz
- Size: 35.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a744b7e5d9d6f76a94708e1593eee8d60c78ca0a0a608dc8a861ebd3a03dd540 |
| MD5 | 6bf5d52c4d790169e4b4ed464e2a71a7 |
| BLAKE2b-256 | 87c3e9853e99cbd984c300cf0adcf0b592a7876c838b199504dd10d9707d7c64 |
File details
Details for the file jmstate-0.17.2-py3-none-any.whl.
File metadata
- Download URL: jmstate-0.17.2-py3-none-any.whl
- Size: 41.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 340816640d2f67037e73fdf6968a9bc751163ad3f24a994d447fb65cb6a61773 |
| MD5 | 969bb3cac1beadc17efa524d6a8938dc |
| BLAKE2b-256 | 84054e317af13c9d69d44e20a87424c093cd01e0570a56801691005753639f3d |