Simulating trajectories from the stochastic Jansen-Rit neural mass model

jrnmm

About

This package implements a pure Python/JAX port of the Jansen-Rit neural mass model (JRNMM) from sdbmpABC. The implementation is vectorized over trajectories, so it should be fairly fast in comparison.

Examples

You can use the package to simulate the readout of the JRNMM like this:

%matplotlib inline
import fybdthemes
import matplotlib.pyplot as plt

from jax import numpy as jnp
from jax import random as jr
from jax.scipy.signal import welch
from jrnmm import simulate

fybdthemes.set_theme()

# this samples 20 trajectories of length 8 / (1 / 128)
# each with initial condition [0.08, 18, 15, -0.5, 0, 0]
# and the same C, mu, sigma and gains
n = 20
C, mu, sigma, gain = 135, 220, 2000, 0.0
y = simulate(
    jr.PRNGKey(1),
    dt=1 / 128,
    t_end=8,
    initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
    Cs=jnp.full(n, C),
    mus=jnp.full(n, mu),
    sigmas=jnp.full(n, sigma),
    gains=jnp.full(n, gain),
)
f, s = welch(y, fs=128, axis=1, nperseg=64)

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
colors = fybdthemes.discrete_sequential_colors(n)
for i in range(n):
    axes[0].plot(y[i, :], color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s[i, :]), color=colors[i], alpha=0.23)
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()
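
The comment in the example gives the trajectory length as 8 / (1 / 128). As a quick check of that arithmetic, the following sketch computes the number of time steps and the frequency grid that a Welch estimate with nperseg=64 produces at fs=128. It uses only the standard library. The helper names n_steps and welch_frequencies are hypothetical and not part of jrnmm, and the sketch assumes that simulate returns one sample per step of size dt, which is not stated explicitly above.

```python
def n_steps(t_end: float, dt: float) -> int:
    """Number of grid points in [0, t_end) with step size dt."""
    return int(round(t_end / dt))


def welch_frequencies(fs: float, nperseg: int) -> list:
    """One-sided frequency grid of a Welch estimate with segment length nperseg."""
    return [k * fs / nperseg for k in range(nperseg // 2 + 1)]


steps = n_steps(t_end=8, dt=1 / 128)
freqs = welch_frequencies(fs=128, nperseg=64)
print(steps)                 # 1024
print(len(freqs), freqs[1])  # 33 2.0
```

Under that assumption, each of the 20 trajectories has 1024 samples, and each periodogram has 33 frequency bins spaced 2 Hz apart, from 0 Hz up to the Nyquist frequency of 64 Hz.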

Compare this to the sdbmpABC solution:

import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
sdbmp = importr("sdbmsABC")

rset_seed = robjects.r["set.seed"]
rchol = robjects.r["chol"]
rt = robjects.r["t"]

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
for i in range(n):
    dt = 1/128
    y0 = robjects.FloatVector(list([0.08, 18, 15, -0.5, 0, 0]))
    grid = robjects.FloatVector(list(jnp.arange(0, 8.0, dt)))
    dm = sdbmp.exp_matJR(dt, 100, 50)
    cm = rt(
        rchol(sdbmp.cov_matJR(dt, robjects.FloatVector([0, 0, 0, 0.01, sigma, 1.0]), 100, 50))
    )
    y = jnp.array(
        sdbmp.Splitting_JRNMM_output_Cpp(
            dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
        )
    )
    f, s = welch(y, fs=128, nperseg=64)
    axes[0].plot(y, color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s), color=colors[i], alpha=0.23)
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()

A timing comparison against sdbmpABC:

from timeit import default_timer as timer

n_repeat = 10
n_iter = 1_000

# reuses dt, y0, grid, dm and cm from the comparison above
timings_r = []
for _ in range(n_repeat):
    start = timer()
    for _ in range(n_iter):
        y = jnp.array(
            sdbmp.Splitting_JRNMM_output_Cpp(
                dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
            )
        )
    end = timer() - start
    timings_r.append(end)

timings_jax = []
for _ in range(n_repeat):
    start = timer()
    y = simulate(
        jr.PRNGKey(1),
        dt=dt,
        t_end=8,
        initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
        Cs=jnp.full(n_iter, C),
        mus=jnp.full(n_iter, mu),
        sigmas=jnp.full(n_iter, sigma),
        gains=jnp.full(n_iter, gain),
    )
    end = timer() - start
    timings_jax.append(end)
print(f"Average time R/C++ to produce {n_iter} trajectories: {sum(timings_r) / n_repeat}")
print(f"Average time JAX to produce {n_iter} trajectories: {sum(timings_jax) / n_repeat}")
Average time R/C++ to produce 1000 trajectories: 3.9917746584003906
Average time JAX to produce 1000 trajectories: 0.6860150751999754
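
From the averages printed above, the JAX port is roughly a factor of 5.8 faster in this run. The sketch below recomputes that ratio and shows a small timing helper that runs one warm-up call before measuring, so that one-off costs such as JIT compilation are excluded. The helper time_call is hypothetical and not part of jrnmm. Because JAX dispatches computations asynchronously, a callable that wraps JAX code should block on its result before returning, for example via jax.block_until_ready.

```python
from timeit import default_timer as timer


def time_call(fn, n_repeat=10, warmup=True):
    """Average wall-clock seconds per call of fn over n_repeat runs."""
    if warmup:
        fn()  # excluded from the measurement
    start = timer()
    for _ in range(n_repeat):
        fn()
    return (timer() - start) / n_repeat


# averages reported above (seconds per 1000 trajectories)
time_r, time_jax = 3.9917746584003906, 0.6860150751999754
speedup = time_r / time_jax
print(f"{speedup:.1f}x")  # 5.8x

# usage with a stand-in workload instead of the simulator
avg = time_call(lambda: sum(range(10_000)), n_repeat=5)
```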

Installation

To install from GitHub, just call:

pip install git+https://github.com/dirmeier/jrnmm@<RELEASE>

where <RELEASE> is the tag of one of the releases listed in the GitHub repository.

Author

Simon Dirmeier sfyrbnd @ pm me
