Virtual brains w/ JAX
Accelerated inference on fields: virtual brains in JAX
basic plan
- implement an shtns-like API plus custom kernels for shtlc
- focus on kernel design parity w/ sparse matrix
- epi2d, heun, obs2d, scan loop in time
- explore batching in time for hybrid
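The "heun, scan loop in time" items can be illustrated with plain JAX. The following is a minimal sketch and not the vbjax API: `heun_step`, `integrate` and the linear `decay` model are hypothetical names. `decay` stands in for a real field model such as epi2d. Each Heun step takes an Euler predictor and then averages the two slopes (trapezoidal corrector). `jax.lax.scan` carries the state forward in time and stacks the trajectory, which keeps the whole loop traceable and jittable.

```python
import jax
import jax.numpy as jnp


def heun_step(dfun, x, p, dt):
    """One Heun (predictor-corrector) step for dx/dt = dfun(x, p)."""
    d1 = dfun(x, p)
    x_pred = x + dt * d1            # Euler predictor
    d2 = dfun(x_pred, p)
    return x + dt / 2 * (d1 + d2)   # trapezoidal corrector


def integrate(dfun, x0, p, dt, n_steps):
    """Scan heun_step over time, returning the (n_steps, *x0.shape) trajectory."""
    def body(x, _):
        x_next = heun_step(dfun, x, p, dt)
        return x_next, x_next       # (new carry, value recorded at this step)
    _, xs = jax.lax.scan(body, x0, None, length=n_steps)
    return xs


# Hypothetical placeholder model: linear decay dx/dt = -p * x (not epi2d).
def decay(x, p):
    return -p * x


xs = integrate(decay, jnp.ones(3), 0.5, 0.01, 1000)
print(xs[-1])  # close to exp(-0.5 * 10) for every node
```

Because `integrate` only needs a `dfun(x, p)` callable, the same loop could be reused with other models. Wrapping `integrate` in `jax.jit` would then compile the whole time loop into a single kernel.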
JAX allows testing different methods:
- numpyro: HMC & VI
- batched eval: efficient parameter sweeps for SBI
- run notebooks as tests with jupyter execute via dockerfile + gh actions
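Batched evaluation for parameter sweeps can be sketched with `jax.vmap`. The simulator is mapped over a vector of parameters and compiled once with `jax.jit`. This is an illustration under assumptions: `simulate` and `summary` are hypothetical names, and the linear decay model stands in for a real brain model. The final state is used as a stand-in summary statistic of the kind SBI would consume. numpyro (HMC, VI) is not shown here.

```python
import jax
import jax.numpy as jnp


def simulate(p, x0=1.0, dt=0.01, n_steps=500):
    """Heun-integrate the placeholder model dx/dt = -p * x; returns the trajectory."""
    def body(x, _):
        d1 = -p * x
        d2 = -p * (x + dt * d1)
        x = x + dt / 2 * (d1 + d2)
        return x, x
    _, xs = jax.lax.scan(body, jnp.float32(x0), None, length=n_steps)
    return xs


def summary(p):
    """Hypothetical summary statistic: the state at the final time step."""
    return simulate(p)[-1]


# One compiled call evaluates all 64 parameter values in a batch.
ps = jnp.linspace(0.1, 2.0, 64)
finals = jax.jit(jax.vmap(summary))(ps)
print(finals[:4])
```

With `vmap`, the parameter axis becomes a batch dimension of every array operation inside the scan. The loop therefore runs once for the whole sweep instead of once per parameter value.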
setup
conda
The primary dependency, jax, is easily installable with conda, e.g.
curl -LO https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -p $SCRATCH/conda-vbjax -b
. $SCRATCH/conda-vbjax/bin/activate
conda install -y -q jax
pip install vbjax
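A quick way to confirm that the jax install works is a small sanity check like the one below. It is only a generic JAX smoke test, not something shipped with vbjax. It prints the version and the devices JAX can see (CPU or GPU), then jit-compiles a trivial function.

```python
import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.devices())  # lists the devices available to JAX (CPU, GPU, ...)

# Trivial jit-compiled function: sum of squares of [0, 1, 2, 3]
x = jnp.arange(4.0)
y = jax.jit(lambda v: (v ** 2).sum())(x)
print(y)  # 14.0
```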
Dockerfile
See cpu.dockerfile or gpu.dockerfile.
pipenv
pipenv is used to manage the dependencies, so you can install them in a virtualenv and run Jupyter Notebook like so:
pip install -U pipenv
pipenv install -d
pipenv run jupyter notebook
installing shtns on mac
In the Python environment, run:
brew install fftw
git clone https://bitbucket.org/nschaeff/shtns
cd shtns
./configure --enable-python --disable-simd --prefix=/opt/homebrew
make -j && make install && python setup.py install
Pipenv
Some packages that may be used are not in the Pipfile because pipenv fails to install them:
- numba
- tvb-data & tvb-library