
Virtual brains w/ JAX


vbjax

vbjax is a JAX-based package for working with virtual brain style models.

Installation

Installs with pip install vbjax, but to use the source,

git clone https://github.com/ins-amu/vbjax
cd vbjax
pip install '.[dev]'

The primary additional dependency of vbjax is JAX, which itself depends only on NumPy, SciPy & opt-einsum, so it should be safe to add to your existing projects. Check the JAX docs for CUDA use, but after the above pip step,

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

But because GPU software stack versions make aligning the stars look like child's play, container images are available and auto-built with GitHub Actions, so you can use vbjax with Docker,

docker run --rm -it ghcr.io/ins-amu/vbjax:main python3 -c 'import vbjax; print(vbjax.__version__)'

The images are built on NVIDIA runtime images, so --gpus all is enough for JAX to discover the GPU(s).
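For example, to check that the GPU is visible inside the container (this exact command is just one way to do it),

docker run --rm -it --gpus all ghcr.io/ins-amu/vbjax:main python3 -c 'import jax; print(jax.devices())'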

Examples

Here's an all-to-all connected network with Montbrió-Pazó-Roxin mass model dynamics,

import vbjax as vb
import jax.numpy as np

def network(x, p):
    c = 0.03*x.sum(axis=1)  # all-to-all coupling: sum over nodes, scaled by 0.03
    return vb.mpr_dfun(x, c, p)

_, loop = vb.make_sde(dt=0.01, dfun=network, gfun=0.1)
zs = vb.randn(500, 2, 32)
xs = loop(zs[0], zs[1:], vb.mpr_default_theta)
vb.plot_states(xs, 'rV', jpg='example1', show=True)

While integrators and mass models tend to be the same across publications, the network model itself varies (regions vs surface, stimulus etc), so vbjax lets the user focus on defining the network and then getting the time series. Because the work is done by JAX, everything is auto-differentiable and GPU-capable, so it plays well with common machine learning algorithms.
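For instance, here's a minimal sketch (not part of the original example; the batch size and shapes are arbitrary) of JIT-compiling the same loop and running a batch of simulations with jax.vmap,

import jax

# compile the loop once and map it over a batch of initial conditions and noise
run = jax.jit(lambda x0, zs: loop(x0, zs, vb.mpr_default_theta))
x0s = vb.randn(8, 2, 32)            # 8 random initial states
noise = vb.randn(8, 500, 2, 32)     # independent noise per simulation
batch = jax.vmap(run)(x0s, noise)   # shape (8, 500, 2, 32)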

Neural field

Here's a neural field,

import jax.numpy as np
import vbjax as vb

# setup local connectivity
lmax, nlat, nlon = 16, 32, 64
lc = vb.make_shtdiff(lmax=lmax, nlat=nlat, nlon=nlon)

# network dynamics
def net(x, p):
    c = lc(x[0]), 0.0  # local coupling via the spherical-harmonic operator on the first state variable
    return vb.mpr_dfun(x, c, p)

# solution + plot
x0 = vb.randn(2, nlat, nlon)*0.5 + np.r_[0.2,-2.0][:,None,None]
_, loop = vb.make_sde(0.1, net, 0.2)
zs = vb.randn(500, 2, nlat, nlon)
xt = loop(x0, zs, vb.mpr_default_theta._replace(eta=-3.9, cr=5.0))
vb.make_field_gif(xt[::10], 'example2.gif')

This example shows how the field forms patterns gradually despite the noise in the simulation.
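If you want to quantify that (a small add-on, not part of the original example; matplotlib is assumed to be available), you can track the spatial spread of the field over time,

import matplotlib.pyplot as plt

# spatial standard deviation of the r state at each time step
spatial_std = xt[:, 0].std(axis=(1, 2))
plt.plot(spatial_std)
plt.xlabel('time step')
plt.ylabel('spatial std of r')
plt.savefig('example2_std.png')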

Fitting an autoregressive process

Here's a 1-lag MVAR,

import jax
import jax.numpy as np
import vbjax as vb

nn = 8
true_A = vb.randn(nn,nn)
_, loop = vb.make_sde(1, lambda x,A: -x+(A*x).mean(axis=1), 1)
x0 = vb.randn(nn)
zs = vb.randn(1000, nn)
xt = loop(x0, zs, true_A)

xt and true_A are the simulated time series and the ground-truth interaction matrix, respectively.

To fit anything we need a loss function & gradient descent,

def loss(est_A):
    return np.sum(np.square(xt - loop(x0, zs, est_A)))

grad_loss = jax.grad(loss)
est_A = np.ones((nn, nn))*0.3  # deliberately wrong initial guess
for i in range(51):
    est_A = est_A - 0.01*grad_loss(est_A)
    if i % 10 == 0:
        print('step', i, 'log loss', np.log(loss(est_A)))

print('mean sq err', np.square(est_A - true_A).mean())

which prints

step 0 log loss 5.8016257
step 10 log loss 3.687574
step 20 log loss 1.7174681
step 30 log loss -0.15798996
step 40 log loss -1.9851608
step 50 log loss -3.7805486
mean sq err 8.422789e-05

This is a pretty simple example, but it shows that any model built like this with vbjax can be used with gradient-based optimization or NumPyro's MCMC algorithms.
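As a taste of the latter, here's a minimal NumPyro sketch reusing loop, x0, zs and nn from the example above (the prior, observation noise and sampler settings are illustrative assumptions, not part of vbjax),

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(x_obs):
    # assumed standard-normal prior on the interaction matrix
    A = numpyro.sample('A', dist.Normal(0, 1).expand([nn, nn]))
    x_hat = loop(x0, zs, A)
    # assumed Gaussian observation noise with scale 0.1
    numpyro.sample('x', dist.Normal(x_hat, 0.1), obs=x_obs)

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(jax.random.PRNGKey(0), xt)
print(mcmc.get_samples()['A'].mean(axis=0))  # posterior mean estimate of A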

HPC usage

We use this on HPC systems, most easily with container images.

CSCS Piz Daint

Useful modules

module load daint-gpu
module load cudatoolkit/11.2.0_3.39-2.1__gf93aa1c
module load TensorFlow

then install a CUDA-enabled JAX in some Python environment; the default one works fine,

pip3 install "jax[cuda]==0.3.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install "jaxlib==0.3.8+cuda11.cudnn805" -U -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

This provides an older version of JAX unfortunately.

The Sarus runtime can be used to make use of latest versions of vbjax and jax:

$ module load daint-gpu
$ module load sarus
$ sarus pull ghcr.io/ins-amu/vbjax:main
...
$ srun -p debug -A ich042 -C gpu --pty sarus run ghcr.io/ins-amu/vbjax:main python3 -c 'import jax; print(jax.numpy.zeros(32).device())'
...
gpu:0

JSC JUSUF

A nice module is available to get CUDA libs

module load cuDNN/8.6.0.163-CUDA-11.7

then you might set up a conda env,

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p ~/conda
. ~/conda/bin/activate
conda create -n jax python=3.9 numpy scipy
source activate jax

once you have an env, install the CUDA-enabled JAX

pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

and check it works

(jax) [woodman1@jsfl02 ~]$ srun -A icei-hbp-2021-0002 -p develgpus --pty python3 -c 'import jax.numpy as np ; print(np.zeros(32).device())'
gpu:0

JSC also makes Singularity available, so the prebuilt image can be used

TODO
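Until that's filled in, a generic sketch (the module name is a guess, not verified JSC instructions) would be to pull and run the image with Singularity,

module load Singularity   # or the site-specific equivalent
singularity pull vbjax.sif docker://ghcr.io/ins-amu/vbjax:main
singularity exec --nv vbjax.sif python3 -c 'import vbjax; print(vbjax.__version__)'
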
CEA

The prebuilt image is the best route:

TODO

Development

git clone https://github.com/ins-amu/vbjax
cd vbjax
pip install '.[dev]'
pytest

Installing SHTns

This library is used for some of the tests. It cannot be installed natively on Windows, so WSL is required.
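Under Linux (or WSL), a typical build might look like the following (the FFTW package name is an assumption for Debian-based systems),

sudo apt-get install libfftw3-dev
git clone https://bitbucket.org/nschaeff/shtns
cd shtns
./configure --enable-python
make -j && python setup.py install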

On macOS,

brew install fftw
git clone https://bitbucket.org/nschaeff/shtns
cd shtns
./configure --enable-python --disable-simd --prefix=/opt/homebrew
make -j && make install && python setup.py install

Releases

A release of version v1.2.3 requires the following steps:

