Simulating trajectories from the stochastic Jansen-Rit neural mass model

jrnmm

About

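This package implements a pure Python/JAX port of the Jansen-Rit neural mass model (JRNMM) from sdbmpABC. The implementation is vectorized over trajectories, so simulating many trajectories in a single call should be noticeably faster than calling the R/C++ reference implementation in a loop (see the timing comparison further down).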
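To illustrate what "vectorized over trajectories" means, here is a minimal NumPy sketch. It is not the JRNMM and not this package's code: it uses a toy Ornstein-Uhlenbeck process with a plain Euler-Maruyama step, and the function `simulate_ou` and its parameters `thetas` and `sigmas` are hypothetical names chosen for this example. The point is only that each time step updates all trajectories at once, with one parameter value per trajectory.

```python
import numpy as np


def simulate_ou(rng, dt, t_end, x0, thetas, sigmas):
    """Toy example: Euler-Maruyama for dX = -theta * X dt + sigma dW.

    `thetas` and `sigmas` hold one parameter value per trajectory, so a
    single call returns an array of shape (n_trajectories, n_steps).
    """
    n_steps = int(round(t_end / dt))
    n = thetas.shape[0]
    x = np.full(n, x0, dtype=float)
    out = np.empty((n, n_steps))
    for k in range(n_steps):
        # One vectorized update advances every trajectory by one step.
        noise = rng.normal(size=n) * np.sqrt(dt)
        x = x - thetas * x * dt + sigmas * noise
        out[:, k] = x
    return out


rng = np.random.default_rng(0)
paths = simulate_ou(
    rng,
    dt=1 / 128,
    t_end=8,
    x0=0.0,
    thetas=np.full(20, 1.0),
    sigmas=np.full(20, 0.5),
)
print(paths.shape)  # (20, 1024)
```

The only Python-level loop runs over time steps; the loop over trajectories is replaced by array operations.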

Examples

You can use the package to simulate the readout of the JRNMM like this:

%matplotlib inline
import fybdthemes
import matplotlib.pyplot as plt

from jax import numpy as jnp
from jax import random as jr
from jax.scipy.signal import welch
from jrnmm import simulate

fybdthemes.set_theme()

# sample n = 20 trajectories, each of length t_end / dt = 8 / (1 / 128),
# all starting from the initial condition [0.08, 18, 15, -0.5, 0, 0]
# and all sharing the same C, mu, sigma and gain
n = 20
C, mu, sigma, gain = 135, 220, 2000, 0.0
y = simulate(
    jr.PRNGKey(1),
    dt=1 / 128,
    t_end=8,
    initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
    Cs=jnp.full(n, C),
    mus=jnp.full(n, mu),
    sigmas=jnp.full(n, sigma),
    gains=jnp.full(n, gain),
)
f, s = welch(y, fs=128, axis=1, nperseg=64)

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
colors = fybdthemes.discrete_sequential_colors(n)
for i in range(n):
    axes[0].plot(y[i, :], color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s[i, :]), color=colors[i], alpha=0.23)
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()
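As a quick sanity check on the spectral settings above, the following sketch works out the numbers implied by `dt=1 / 128`, `t_end=8` and `nperseg=64`. It assumes a one-sided spectrum for real-valued input (the default of `welch`) and a time grid that excludes the end point, as in `jnp.arange(0, 8.0, dt)`.

```python
import numpy as np

fs = 128      # sampling rate in Hz, i.e. 1 / dt
t_end = 8     # simulated time in seconds
nperseg = 64  # segment length passed to welch

# Number of samples per trajectory (assumes the grid excludes t_end).
n_steps = int(round(t_end * fs))

# One-sided frequency grid of a length-64 segment sampled at 128 Hz.
freqs = np.fft.rfftfreq(nperseg, d=1 / fs)
n_bins = freqs.shape[0]
resolution = freqs[1] - freqs[0]
nyquist = freqs[-1]

# 1024 samples, 33 frequency bins, 2 Hz spacing, 64 Hz Nyquist frequency
print(n_steps, n_bins, resolution, nyquist)
```

A larger `nperseg` would give a finer frequency grid at the cost of averaging over fewer segments.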

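[Figure: the 20 simulated JAX trajectories (left panel, "Readout") and their log power spectra (right panel, "Periodogram").]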

For comparison, here is the same simulation using the R/C++ implementation from sdbmpABC, called from Python via rpy2:

import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
sdbmp = importr("sdbmsABC")

rset_seed = robjects.r["set.seed"]
rchol = robjects.r["chol"]
rt = robjects.r["t"]

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
for i in range(n):
    dt = 1/128
    y0 = robjects.FloatVector(list([0.08, 18, 15, -0.5, 0, 0]))
    grid = robjects.FloatVector(list(jnp.arange(0, 8.0, dt)))
    dm = sdbmp.exp_matJR(dt, 100, 50)
    cm = rt(
        rchol(sdbmp.cov_matJR(dt, robjects.FloatVector([0, 0, 0, 0.01, sigma, 1.0]), 100, 50))
    )
    y = jnp.array(
        sdbmp.Splitting_JRNMM_output_Cpp(
            dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
        )
    )
    f, s = welch(y, fs=128, nperseg=64)
    axes[0].plot(y, color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s), color=colors[i], alpha=0.23)
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()
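The bare numbers in the R call are easier to read with names attached. The sketch below assumes that the positional arguments `3.25, 22, 100, 50, 6, 0.56, 5.0` are the usual Jansen-Rit constants A, B, a, b, v0, r and vmax, which is what their values suggest; this mapping was not checked against the R function signature. The dictionary `JR_CONSTANTS` and the function `sigmoid` are names made up for this example.

```python
import math

# Assumed meaning of the constants passed to the R function
# (standard Jansen-Rit parametrization; an assumption, see above).
JR_CONSTANTS = {
    "A": 3.25,    # excitatory synaptic gain
    "B": 22.0,    # inhibitory synaptic gain
    "a": 100.0,   # inverse time constant of the excitatory population
    "b": 50.0,    # inverse time constant of the inhibitory population
    "v0": 6.0,    # potential at which half the maximum firing rate is reached
    "r": 0.56,    # slope of the sigmoid
    "vmax": 5.0,  # maximum firing rate
}


def sigmoid(x, v0=JR_CONSTANTS["v0"], r=JR_CONSTANTS["r"], vmax=JR_CONSTANTS["vmax"]):
    """Jansen-Rit sigmoid mapping a membrane potential to a firing rate."""
    return vmax / (1.0 + math.exp(r * (v0 - x)))


print(sigmoid(6.0))  # 2.5, i.e. half of vmax at x = v0
```

In the usual formulation the connectivity parameter C scales the arguments and outputs of this sigmoid, and the readout is the difference between the second and third state components.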

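[Figure: the 20 sdbmpABC trajectories (left panel, "Readout") and their log power spectra (right panel, "Periodogram").]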

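A rough timing comparison against sdbmpABC, where the R/C++ code is called once per trajectory and the JAX code produces all trajectories in one call: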

from timeit import default_timer as timer

n_repeat = 10
n_iter = 1_000

timings_r = []
for _ in range(n_repeat):
    start = timer()
    for _ in range(n_iter):
        y = jnp.array(
            sdbmp.Splitting_JRNMM_output_Cpp(
                dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
            )
        )
    end = timer() - start
    timings_r.append(end)

timings_jax = []
for i in range(n_repeat):
    start = timer()
    y = simulate(
        jr.PRNGKey(1),
        dt=dt,
        t_end=8,
        initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
        Cs=jnp.full(n_iter, C),
        mus=jnp.full(n_iter, mu),
        sigmas=jnp.full(n_iter, sigma),
        gains=jnp.full(n_iter, gain),
    )
    end = timer() - start
    timings_jax.append(end)
print(f"Average time R/C++ to produce {n_iter} trajectories: {sum(timings_r) / n_repeat}")
print(f"Average time JAX to produce {n_iter} trajectories: {sum(timings_jax) / n_repeat}")

Average time R/C++ to produce 1000 trajectories: 3.9917746584003906
Average time JAX to produce 1000 trajectories: 0.6860150751999754
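The reported averages are in seconds and correspond to a speedup of roughly 5.8x. When repeating such a measurement, two things are worth keeping in mind. JAX dispatches work asynchronously, so the timed call should wait for the result, for example with `jax.block_until_ready(...)`. If the simulator is JIT-compiled (an assumption, not checked here), the first call also includes compilation time, so a warm-up call helps. The helper below is a minimal, library-agnostic sketch; `benchmark` is a hypothetical name and the lambda is a stand-in workload.

```python
from timeit import default_timer as timer


def benchmark(fn, n_repeat=10, n_warmup=1):
    """Return the mean wall-clock time of fn() in seconds."""
    # Warm-up calls are executed but not timed.
    for _ in range(n_warmup):
        fn()
    timings = []
    for _ in range(n_repeat):
        start = timer()
        fn()
        timings.append(timer() - start)
    return sum(timings) / n_repeat


# Stand-in workload; for JAX, fn should block until the result is ready.
mean_seconds = benchmark(lambda: sum(range(10_000)))

# Speedup implied by the averages reported above.
speedup = 3.9917746584003906 / 0.6860150751999754
print(f"{speedup:.1f}x")  # 5.8x
```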

Installation

To install from GitHub, just call:

pip install git+https://github.com/dirmeier/jrnmm@<RELEASE>

where <RELEASE> is a release tag; the available versions are listed in the GitHub repository.
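After installing, the version that ended up in the environment can be checked from Python. This is a small sketch using only the standard library; `installed_version` is a hypothetical helper name, and it assumes the distribution is registered under the name `jrnmm`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print(installed_version("jrnmm"))
```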

Author

Simon Dirmeier sfyrbnd @ pm me
