jrnmm

About

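This package implements a pure Python/JAX port of the Jansen-Rit neural mass model (JRNMM) from sdbmpABC. The implementation is vectorized over parameter sets, so a whole batch of trajectories is simulated in a single call; in the timing comparison below it is roughly six times faster than the R/C++ reference when producing 1,000 trajectories.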

Examples

You can use the package to simulate the readout of the JRNMM like this:

%matplotlib inline
import fybdthemes
import matplotlib.pyplot as plt

from jax import numpy as jnp
from jax import random as jr
from jax.scipy.signal import welch
from jrnmm import simulate

fybdthemes.set_theme()

# sample n = 20 trajectories up to t_end = 8 with step size dt = 1 / 128,
# i.e., 8 / (1 / 128) = 1024 steps each; all trajectories start from the
# initial state [0.08, 18, 15, -0.5, 0, 0] and share the same C, mu, sigma and gain
n = 20
C, mu, sigma, gain = 135, 220, 2000, 0.0
y = simulate(
    jr.PRNGKey(1),
    dt=1 / 128,
    t_end=8,
    initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
    Cs=jnp.full(n, C),
    mus=jnp.full(n, mu),
    sigmas=jnp.full(n, sigma),
    gains=jnp.full(n, gain),
)
f, s = welch(y, fs=128, axis=1, nperseg=64)

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
colors = fybdthemes.discrete_sequential_colors(n)
for i in range(n):
    axes[0].plot(y[i, :], color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s[i, :]), color=colors[i], alpha=0.23)
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()

[Figure: readout (left) and log periodogram (right) of the 20 trajectories simulated with jrnmm]
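The sizes involved in the example above follow from simple arithmetic, sketched here in plain Python. The one-sided Welch estimate of a real-valued signal has `nperseg // 2 + 1` frequency bins spaced `fs / nperseg` apart. Whether `simulate` returns exactly `t_end / dt` samples per trajectory (rather than, say, one more for the initial state) is an assumption that is not confirmed here.

```python
# Values taken from the example above.
dt = 1 / 128
t_end = 8
fs = 128
nperseg = 64

# Number of integration steps per trajectory
# (assumed to equal the number of samples returned by simulate).
n_steps = round(t_end / dt)

# One-sided spectrum of a real-valued signal.
n_freqs = nperseg // 2 + 1
freq_resolution = fs / nperseg  # spacing between frequency bins in Hz
nyquist = fs / 2  # highest frequency in the periodogram in Hz

print(n_steps, n_freqs, freq_resolution, nyquist)
# prints: 1024 33 2.0 64.0
```

With `nperseg=64`, each periodogram in the right panel therefore has 33 points covering 0 to 64 Hz in steps of 2 Hz.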

Compare this to the sdbmpABC solution:

import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
sdbmp = importr("sdbmsABC")

rset_seed = robjects.r["set.seed"]
rchol = robjects.r["chol"]
rt = robjects.r["t"]

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
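# the step size, initial state, time grid and the matrices of the splitting
# scheme are the same for every draw, so they are built once
dt = 1 / 128
y0 = robjects.FloatVector([0.08, 18, 15, -0.5, 0, 0])
grid = robjects.FloatVector(jnp.arange(0, 8.0, dt).tolist())
dm = sdbmp.exp_matJR(dt, 100, 50)
# R's chol returns the upper-triangular factor; transposing gives the lower one
cm = rt(
    rchol(sdbmp.cov_matJR(dt, robjects.FloatVector([0, 0, 0, 0.01, sigma, 1.0]), 100, 50))
)
for i in range(n):
    # the trailing constants match the standard Jansen-Rit values
    # of A, B, a, b, v0, r and vmax
    y = jnp.array(
        sdbmp.Splitting_JRNMM_output_Cpp(
            dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
        )
    )
    f, s = welch(y, fs=128, nperseg=64)
    axes[0].plot(y, color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s), color=colors[i], alpha=0.23)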
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()

[Figure: readout (left) and log periodogram (right) of 20 trajectories simulated with sdbmpABC]

A rough timing comparison against sdbmpABC, producing 1,000 trajectories per run and averaging over 10 runs:

from timeit import default_timer as timer

n_repeat = 10
n_iter = 1_000

timings_r = []
for _ in range(n_repeat):
    start = timer()
    # the R/C++ function produces one trajectory per call
    for _ in range(n_iter):
        y = jnp.array(
            sdbmp.Splitting_JRNMM_output_Cpp(
                dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
            )
        )
    elapsed = timer() - start
    timings_r.append(elapsed)

timings_jax = []
for _ in range(n_repeat):
    start = timer()
    # a single call produces all n_iter trajectories at once
    y = simulate(
        jr.PRNGKey(1),
        dt=dt,
        t_end=8,
        initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
        Cs=jnp.full(n_iter, C),
        mus=jnp.full(n_iter, mu),
        sigmas=jnp.full(n_iter, sigma),
        gains=jnp.full(n_iter, gain),
    )
    elapsed = timer() - start
    timings_jax.append(elapsed)
print(f"Average time R/C++ to produce {n_iter} trajectories: {sum(timings_r) / n_repeat}")
print(f"Average time JAX to produce {n_iter} trajectories: {sum(timings_jax) / n_repeat}")
Average time R/C++ to produce 1000 trajectories: 3.9917746584003906
Average time JAX to produce 1000 trajectories: 0.6860150751999754
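The two averages above correspond to a speedup of roughly 5.8x. When repeating such a measurement, keep in mind that JAX dispatches work asynchronously and compiles jitted functions on their first call, so a timing loop usually benefits from an untimed warm-up call and from waiting for the result before stopping the clock. The following is a minimal sketch of that pattern in plain Python; `time_call` and its `sync` argument are hypothetical names introduced for illustration, and whether `simulate` is compiled internally is not confirmed here.

```python
from timeit import default_timer as timer

# Averages reported above, in seconds.
time_r = 3.9917746584003906
time_jax = 0.6860150751999754
speedup = time_r / time_jax
print(f"speedup: {speedup:.1f}x")  # prints: speedup: 5.8x


def time_call(fn, n_repeat=10, n_warmup=1, sync=lambda out: out):
    """Hypothetical helper: average wall time of fn() over n_repeat calls."""
    for _ in range(n_warmup):
        sync(fn())  # warm-up calls are not timed (absorbs one-off costs)
    timings = []
    for _ in range(n_repeat):
        start = timer()
        sync(fn())  # wait for the result before stopping the clock
        timings.append(timer() - start)
    return sum(timings) / n_repeat


# Usage with a plain Python workload; the default sync does nothing.
avg = time_call(lambda: sum(range(10_000)), n_repeat=5)
print(f"average time: {avg:.6f} s")
```

For a JAX workload, one would pass `sync=jax.block_until_ready` so that each timed call includes the actual computation.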

Installation

To install from GitHub, just call:

pip install git+https://github.com/dirmeier/jrnmm@<RELEASE>

where <RELEASE> is a release tag; the available releases are listed in the GitHub repository.
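After installing, a quick way to confirm which version ended up in the environment is to query the package metadata. This is a minimal sketch using only the standard library; it assumes the distribution is registered under the name `jrnmm`, and `installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of package, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the version string if jrnmm is installed, otherwise None.
print(installed_version("jrnmm"))
```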

Author

Simon Dirmeier sfyrbnd @ pm me
