jrnmm
About
This package implements a pure Python/JAX port of the Jansen-Rit neural mass model (JRNMM) from sdbmpABC. The implementation is vectorized over trajectories, so it should be fairly fast compared to the original R/C++ code.
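For reference, here is our summary of the stochastic JRNMM in the six-dimensional form commonly used with the splitting schemes behind sdbmpABC. It is not extracted from this package's source, so details (for example the role of the `gains` argument) may differ. The observed readout is $Y_t = X_{2,t} - X_{3,t}$; $C$ scales the connectivity constants, $\mu$ is the mean input and $\sigma_5 = \sigma$ the input noise intensity (in the R comparison below, the noise intensities are $(\sigma_4, \sigma_5, \sigma_6) = (0.01, \sigma, 1)$):

```latex
\begin{aligned}
dX_{i,t} &= X_{i+3,t}\,dt, \qquad i = 1, 2, 3,\\
dX_{4,t} &= \big[A a\,\mathrm{Sigm}(X_{2,t} - X_{3,t}) - 2a X_{4,t} - a^2 X_{1,t}\big]\,dt + \sigma_4\,dW_{4,t},\\
dX_{5,t} &= \big[A a\,\big(\mu + C_2\,\mathrm{Sigm}(C_1 X_{1,t})\big) - 2a X_{5,t} - a^2 X_{2,t}\big]\,dt + \sigma_5\,dW_{5,t},\\
dX_{6,t} &= \big[B b\,C_4\,\mathrm{Sigm}(C_3 X_{1,t}) - 2b X_{6,t} - b^2 X_{3,t}\big]\,dt + \sigma_6\,dW_{6,t},\\
\mathrm{Sigm}(x) &= \frac{v_{\max}}{1 + e^{r (v_0 - x)}}, \qquad C_1 = C,\; C_2 = 0.8\,C,\; C_3 = C_4 = 0.25\,C.
\end{aligned}
```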
Examples
You can use the package to simulate the readout of the JRNMM like this:
```python
# %matplotlib inline  (only needed inside a Jupyter notebook)
import fybdthemes
import matplotlib.pyplot as plt
from jax import numpy as jnp
from jax import random as jr
from jax.scipy.signal import welch

from jrnmm import simulate

fybdthemes.set_theme()

# sample n = 20 trajectories of length t_end / dt = 8 / (1 / 128),
# each with initial condition [0.08, 18, 15, -0.5, 0, 0]
# and the same C, mu, sigma and gain
n = 20
C, mu, sigma, gain = 135, 220, 2000, 0.0
y = simulate(
    jr.PRNGKey(1),
    dt=1 / 128,
    t_end=8,
    initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
    Cs=jnp.full(n, C),
    mus=jnp.full(n, mu),
    sigmas=jnp.full(n, sigma),
    gains=jnp.full(n, gain),
)

# Welch estimate of the power spectral density of each readout (one per row)
f, s = welch(y, fs=128, axis=1, nperseg=64)

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
colors = fybdthemes.discrete_sequential_colors(n)
for i in range(n):
    axes[0].plot(y[i, :], color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s[i, :]), color=colors[i], alpha=0.23)
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()
```
Compare this to the output of sdbmpABC's R/C++ implementation, called through rpy2:
```python
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr

sdbmp = importr("sdbmsABC")
rset_seed = robjects.r["set.seed"]
rchol = robjects.r["chol"]
rt = robjects.r["t"]

# the time grid and the matrices of the splitting scheme are identical for
# every trajectory, so they are computed once outside the loop
dt = 1 / 128
y0 = robjects.FloatVector([0.08, 18, 15, -0.5, 0, 0])
grid = robjects.FloatVector(jnp.arange(0, 8.0, dt).tolist())
dm = sdbmp.exp_matJR(dt, 100, 50)
# transposed Cholesky factor of the noise covariance matrix
cm = rt(
    rchol(sdbmp.cov_matJR(dt, robjects.FloatVector([0, 0, 0, 0.01, sigma, 1.0]), 100, 50))
)

_, axes = plt.subplots(figsize=(12, 3), ncols=2)
for i in range(n):
    # trailing arguments are the fixed Jansen-Rit constants
    # A = 3.25, B = 22, a = 100, b = 50, v0 = 6, r = 0.56, vmax = 5
    y = jnp.array(
        sdbmp.Splitting_JRNMM_output_Cpp(
            dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
        )
    )
    f, s = welch(y, fs=128, nperseg=64)
    axes[0].plot(y, color=colors[i], alpha=0.23)
    axes[1].plot(jnp.log(s), color=colors[i], alpha=0.23)
axes[0].set_title("Readout", fontsize=13)
axes[1].set_title("Periodogram", fontsize=13)
plt.show()
```
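To make "vectorized" concrete, here is a minimal NumPy sketch of a batched JRNMM simulator. It is not the package's code: `jr_sigmoid`, `jr_drift` and `simulate_em` are hypothetical names, and it swaps in a plain Euler-Maruyama step instead of the splitting scheme that sdbmpABC uses. The point it illustrates is that each time step updates all trajectories, each with its own `C`, `mu` and `sigma`, with a handful of array operations rather than one call per trajectory.

```python
import numpy as np

# Jansen-Rit constants (same values as passed to the R function above)
A, B, a, b = 3.25, 22.0, 100.0, 50.0
V0, R, VMAX = 6.0, 0.56, 5.0


def jr_sigmoid(x):
    """Sigmoid turning a mean membrane potential into a firing rate."""
    return VMAX / (1.0 + np.exp(R * (V0 - x)))


def jr_drift(x, C, mu):
    """Drift of the six-dimensional JRNMM for a batch x of shape (n, 6).

    C and mu have shape (n,): one parameter value per trajectory.
    """
    x1, x2, x3, x4, x5, x6 = x.T
    C1, C2, C3, C4 = C, 0.8 * C, 0.25 * C, 0.25 * C
    dx4 = A * a * jr_sigmoid(x2 - x3) - 2 * a * x4 - a**2 * x1
    dx5 = A * a * (mu + C2 * jr_sigmoid(C1 * x1)) - 2 * a * x5 - a**2 * x2
    dx6 = B * b * C4 * jr_sigmoid(C3 * x1) - 2 * b * x6 - b**2 * x3
    return np.stack([x4, x5, x6, dx4, dx5, dx6], axis=1)


def simulate_em(rng, dt, t_end, x0, Cs, mus, sigmas):
    """Euler-Maruyama for a batch of trajectories; returns readouts X2 - X3.

    Output has shape (n, t_end / dt); column k is the readout at step k.
    """
    n = Cs.shape[0]
    n_steps = int(round(t_end / dt))
    # noise enters components 4-6 only, with intensities (0.01, sigma, 1.0)
    zeros = np.zeros(n)
    noise_scale = np.stack(
        [zeros, zeros, zeros, np.full(n, 0.01), sigmas, np.ones(n)], axis=1
    )
    x = np.tile(np.asarray(x0, dtype=float), (n, 1))
    out = np.empty((n, n_steps))
    for k in range(n_steps):
        out[:, k] = x[:, 1] - x[:, 2]
        dw = rng.normal(scale=np.sqrt(dt), size=(n, 6))
        x = x + jr_drift(x, Cs, mus) * dt + noise_scale * dw
    return out
```

For example, `simulate_em(np.random.default_rng(0), 1 / 1024, 1.0, [0.08, 18, 15, -0.5, 0, 0], np.full(4, 135.0), np.full(4, 220.0), np.full(4, 2000.0))` returns a `(4, 1024)` array of readouts. Euler-Maruyama may need a smaller `dt` than a splitting scheme to be accurate for this model, which is one reason the splitting integrator is preferred in practice.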
A timing comparison against sdbmpABC (times in seconds):
```python
from timeit import default_timer as timer

n_repeat = 10
n_iter = 1_000

# R/C++: one call per trajectory
timings_r = []
for _ in range(n_repeat):
    start = timer()
    for _ in range(n_iter):
        y = jnp.array(
            sdbmp.Splitting_JRNMM_output_Cpp(
                dt, y0, grid, dm, cm, mu, C, 3.25, 22, 100, 50, 6, 0.56, 5.0
            )
        )
    end = timer() - start
    timings_r.append(end)

# JAX: all n_iter trajectories in a single vectorized call
timings_jax = []
for _ in range(n_repeat):
    start = timer()
    y = simulate(
        jr.PRNGKey(1),
        dt=dt,
        t_end=8,
        initial_states=jnp.array([0.08, 18, 15, -0.5, 0, 0]),
        Cs=jnp.full(n_iter, C),
        mus=jnp.full(n_iter, mu),
        sigmas=jnp.full(n_iter, sigma),
        gains=jnp.full(n_iter, gain),
    )
    end = timer() - start
    timings_jax.append(end)

print(f"Average time R/C++ to produce {n_iter} trajectories: {sum(timings_r) / n_repeat}")
print(f"Average time JAX to produce {n_iter} trajectories: {sum(timings_jax) / n_repeat}")
```

```
Average time R/C++ to produce 1000 trajectories: 3.9917746584003906
Average time JAX to produce 1000 trajectories: 0.6860150751999754
```
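JAX dispatches work asynchronously, and if `simulate` is JIT-compiled its first call also pays a one-off compilation cost; the JAX loop above handles neither explicitly. A minimal sketch of a timing helper that does, assuming nothing about jrnmm itself (`time_calls` is a hypothetical name), runs a few untimed warm-up calls and applies a user-supplied `sync` function to each result before stopping the clock:

```python
import statistics
from timeit import default_timer as timer


def time_calls(fn, *args, n_repeat=10, n_warmup=1, sync=lambda out: out, **kwargs):
    """Mean wall-clock time in seconds of fn(*args, **kwargs).

    n_warmup untimed calls run first (for JAX this absorbs tracing and
    compilation); sync is applied to every result before the clock stops
    (for JAX, pass jax.block_until_ready so queued device work is counted).
    """
    for _ in range(n_warmup):
        sync(fn(*args, **kwargs))
    timings = []
    for _ in range(n_repeat):
        start = timer()
        sync(fn(*args, **kwargs))
        timings.append(timer() - start)
    return statistics.mean(timings)
```

With the names from the example above, this could be called as `time_calls(simulate, jr.PRNGKey(1), dt=dt, t_end=8, initial_states=..., Cs=..., mus=..., sigmas=..., gains=..., sync=jax.block_until_ready)`, where `jax.block_until_ready` is available in recent JAX versions.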
Installation
To install from GitHub, call:

```bash
pip install git+https://github.com/dirmeier/jrnmm@<RELEASE>
```

where `<RELEASE>` is the version tag of a release.
Author
Simon Dirmeier sfyrbnd @ pm me
Download files
File details
Details for the file jrnmm-0.1.0.post3.tar.gz.
File metadata
- Download URL: jrnmm-0.1.0.post3.tar.gz
- Upload date:
- Size: 625.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.9.15
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2d5f8b6ae50c4d2232f631a193a2eeb008e8b6f33004a41df714bb3743c0670e |
| MD5 | a40a9b325baee148c35a970752431338 |
| BLAKE2b-256 | 8a8eeadcddeca7f774444010613c99219bb15a07dcbbb422836c4a9915994531 |
File details
Details for the file jrnmm-0.1.0.post3-py3-none-any.whl.
File metadata
- Download URL: jrnmm-0.1.0.post3-py3-none-any.whl
- Upload date:
- Size: 15.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.9.15
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6c36cbbe510598ef52bc9c37697fc03faa0f166811b55e877d0c6e2844dd5354 |
| MD5 | 4303d7afc0ae445c0bb6e26f0bc00c26 |
| BLAKE2b-256 | 7849875e9fd6d2ac1199efbba22019a0729bc75de6c82137d7f4bad820465d15 |
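To check a downloaded file against the SHA256 digests listed in the tables, a short standard-library sketch can be used (`sha256_of_file` and `EXPECTED_SHA256` are illustrative names; the path is wherever you saved the file). It reads the file in chunks so large files do not have to fit in memory:

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Hex SHA256 digest of the file at path, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digests copied from the file details above
EXPECTED_SHA256 = {
    "jrnmm-0.1.0.post3.tar.gz": "2d5f8b6ae50c4d2232f631a193a2eeb008e8b6f33004a41df714bb3743c0670e",
    "jrnmm-0.1.0.post3-py3-none-any.whl": "6c36cbbe510598ef52bc9c37697fc03faa0f166811b55e877d0c6e2844dd5354",
}
```

For example, `sha256_of_file("jrnmm-0.1.0.post3.tar.gz") == EXPECTED_SHA256["jrnmm-0.1.0.post3.tar.gz"]` should be `True` for an intact download.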