Lightweight library that automatically parses and converts SBML models into Python models written end-to-end in JAX
About
SBMLtoODEjax is a lightweight library that automatically parses and converts SBML models into Python models written end-to-end in JAX, a high-performance numerical computing library with automatic differentiation capabilities. SBMLtoODEjax is targeted at researchers who want to incorporate SBML-specified ordinary differential equation (ODE) models into their Python projects and machine learning pipelines, in order to perform efficient numerical simulation and optimization with only a few lines of code (by taking advantage of JAX's core transformation features).
SBMLtoODEjax extends SBMLtoODEpy, a Python library developed in 2019 for converting SBML files into Python files written in NumPy/SciPy. The conventions chosen for the generated variables and modules differ slightly from the standard SBML conventions (used in the SBMLtoODEpy library), with the aim of accommodating more flexible manipulations while preserving a JAX-like functional programming style.
👉 In short, SBMLtoODEjax facilitates the reuse of biological network models and their manipulation in Python projects, while tailoring them to take advantage of JAX's main features for efficient and parallel computations.
📖 The documentation, notebook tutorials and public API are available at https://developmentalsystems.org/sbmltoodejax/.
Installation
The latest stable release of SBMLtoODEjax can be installed via pip:
pip install sbmltoodejax
Requires SBMLtoODEpy, JAX (cpu) and Equinox.
Why use SBMLtoODEjax?
Simplicity and extensibility
SBMLtoODEjax retains the simplicity of the original SBMLtoODEpy library to facilitate the incorporation and refactoring of the ODE models into one's own Python projects. As shown below, with only a few lines of Python code one can load and simulate existing SBML files.
Example code (left) and output snapshot (right) reproducing the original simulation results of Kholodenko's 2000 paper, hosted on the BioModels website.
👉 Check our Numerical Simulation tutorial to reproduce the results yourself and see more examples.
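As a self-contained illustration of what simulating an ODE model in JAX involves, here is a minimal sketch that integrates a hand-written toy model with jax.experimental.ode.odeint, the solver SBMLtoODEjax currently relies on. It does not load an SBML file: the two-species reaction A → B, the rate constant value and the function name rate_of_species_change are assumptions chosen for illustration, and the analytic solution [A](t) = exp(-k t) is computed alongside as a check.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


# Hypothetical toy model (hand-written, not generated from an SBML file):
# a single reaction A -> B with mass-action flux k * [A].
def rate_of_species_change(y, t, c):
    """Return dy/dt for species amounts y = [A, B] and constants c = [k]."""
    flux = c[0] * y[0]
    return jnp.array([-flux, flux])


y0 = jnp.array([1.0, 0.0])         # initial amounts of A and B (assumed)
c = jnp.array([0.5])               # rate constant k (assumed value)
ts = jnp.linspace(0.0, 10.0, 101)  # output time points

# odeint(func, y0, t, *args) calls func(y, t, *args) internally
ys = odeint(rate_of_species_change, y0, ts, c)  # shape (101, 2)

# Analytic solution for comparison: [A](t) = exp(-k t)
A_exact = jnp.exp(-c[0] * ts)
```

Since the single reaction only moves mass from A to B, the total amount A + B stays constant along the trajectory, which is a cheap sanity check for any simulated model of this kind.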
JAX-friendly
The generated Python models are tailored to take advantage of JAX's main features.
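from typing import Callable
import equinox as eqx
import jax.numpy as jnp
from jax import jit, lax

class ModelRollout(eqx.Module):
    deltaT: float  # integration time step between two recorded states
    modelstepfunc: Callable  # maps (y, w, c, t, deltaT) to the next (y, w, c, t)

    def __call__(self, n_steps, y0, w0, c, t0=0.0):
        @jit  # use of jit transformation decorator
        def f(carry, x):
            y, w, c, t = carry
            return self.modelstepfunc(y, w, c, t, self.deltaT), (y, w, t)
        # use of scan primitive to replace for loop (and reduce compilation time)
        (y, w, c, t), (ys, ws, ts) = lax.scan(f, (y0, w0, c, t0), jnp.arange(n_steps))
        # move the time axis last, e.g. ys has shape (n_species, n_steps)
        ys = jnp.moveaxis(ys, 0, -1)
        ws = jnp.moveaxis(ws, 0, -1)
        return ys, ws, ts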
As shown above, model rollouts use the jit transformation and the scan primitive to reduce the compilation and execution time of the recursive ODE integration steps, which is particularly useful when running large numbers of steps (long reaction times). Models also inherit from the Equinox module abstraction and are registered as PyTree containers, which facilitates the application of JAX core transformations to any SBMLtoODEjax object.
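To see the same jit-plus-scan pattern run end to end without Equinox, here is a minimal pure-JAX sketch. The toy reaction A → B, the helper names model_step and rollout, and the step size are assumptions for illustration; model_step only mirrors the (y, w, c, t, deltaT) call signature of modelstepfunc in the ModelRollout snippet above and is not the library's generated code.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.ode import odeint


# Hypothetical toy model, A -> B with flux c[0] * A (hand-written, not generated)
def rate_of_species_change(y, t, c):
    flux = c[0] * y[0]
    return jnp.array([-flux, flux])


def model_step(y, w, c, t, deltaT):
    """Advance the state by deltaT using the (y, w, c, t) convention."""
    # odeint returns the state at both [t, t + deltaT]; keep the last row
    y_next = odeint(rate_of_species_change, y, jnp.stack([t, t + deltaT]), c)[-1]
    return y_next, w, c, t + deltaT  # toy model: no assignment rules update w


@partial(jax.jit, static_argnames="n_steps")  # n_steps fixes the scan length
def rollout(y0, w0, c, n_steps, t0=0.0, deltaT=0.1):
    def f(carry, _):
        y, w, c, t = carry
        # emit the state *before* the step, as ModelRollout above does
        return model_step(y, w, c, t, deltaT), (y, w, t)

    init = (y0, w0, c, jnp.asarray(t0, jnp.float32))
    # lax.scan traces f once instead of unrolling a Python loop n_steps times
    _, (ys, ws, ts) = lax.scan(f, init, None, length=n_steps)
    # put the time axis last, e.g. ys has shape (n_species, n_steps)
    return jnp.moveaxis(ys, 0, -1), jnp.moveaxis(ws, 0, -1), ts


ys, ws, ts = rollout(jnp.array([1.0, 0.0]), jnp.zeros(0), jnp.array([0.5]), n_steps=100)
```

Marking n_steps as static is what lets lax.scan know its loop length at trace time; changing it triggers a recompilation, while changing y0 or c does not.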
Efficient simulation and optimization
The application of JAX core transformations, such as just-in-time compilation (jit), automatic vectorization (vmap) and automatic differentiation (grad), to the generated models makes it easy (and seamless) to run simulations efficiently and in parallel.
For instance, as shown below, with only a few lines of Python code one can vectorize calls to the model rollout and perform batched computations efficiently, which is particularly useful for large batch sizes. (left) Example code to vectorize calls to the model rollout. (right) Results of a (rudimentary) benchmark comparing the average simulation time of models implemented with SBMLtoODEpy versus SBMLtoODEjax, for different numbers of rollouts (i.e. batch sizes).
👉 Check our Benchmarking notebook for additional details on the benchmark results.
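The vectorization described above can be sketched with jax.vmap. This is a minimal sketch under assumptions: a hand-written toy reaction A → B stands in for a generated model, and rollout, N_STEPS and DELTA_T are hypothetical names chosen here. Mapping over axis 0 of the initial conditions while sharing the parameters c (in_axes=None) turns a single-trajectory function into a batched one.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.ode import odeint

N_STEPS = 50   # number of integration steps per rollout (assumed)
DELTA_T = 0.1  # integration interval per step (assumed)


# Hypothetical toy model, A -> B with flux c[0] * A (hand-written, not generated)
def rate_of_species_change(y, t, c):
    flux = c[0] * y[0]
    return jnp.array([-flux, flux])


def rollout(y0, c):
    """Simulate one trajectory; returns species amounts of shape (N_STEPS, 2)."""
    def f(carry, _):
        y, t = carry
        y_next = odeint(rate_of_species_change, y, jnp.stack([t, t + DELTA_T]), c)[-1]
        return (y_next, t + DELTA_T), y_next

    _, ys = lax.scan(f, (y0, jnp.asarray(0.0, jnp.float32)), None, length=N_STEPS)
    return ys


# Map over axis 0 of y0 only; c is shared across the batch (in_axes=None)
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(0, None)))

initial_A = jnp.array([0.5, 1.0, 2.0, 4.0])
y0s = jnp.stack([initial_A, jnp.zeros_like(initial_A)], axis=1)  # shape (4, 2)
ys = batched_rollout(y0s, jnp.array([0.5]))                      # shape (4, 50, 2)
```

Wrapping the vmapped function in jit compiles one program for the whole batch rather than looping over initial conditions in Python.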
Finally, as shown below, SBMLtoODEjax models can also be integrated within Optax pipelines (Optax is a gradient processing and optimization library for JAX), allowing one to optimize model parameters and/or external interventions with stochastic gradient descent.
(left) Default simulation results of biomodel #145, which models ATP-induced intracellular calcium oscillations, and the target sine-wave pattern for the Ca_Cyt concentration. (middle) Training loss obtained when running the Optax optimization loop, with the Adam optimizer, over the model's kinetic parameters c. (right) Simulation results obtained after optimization.
👉 Check our Gradient Descent tutorial to reproduce the result yourself and try more advanced optimization usages.
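The core idea of this optimization workflow, differentiating a simulation with respect to its parameters, can be sketched without Optax. This minimal sketch swaps in plain gradient descent with jax.value_and_grad (not the Adam optimizer used in the tutorial) and fits the rate constant of a hypothetical A → B reaction so that the simulated [A](t) matches a target trajectory. The toy model, the target rate 0.8, the initial guess and the learning rate are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


# Hypothetical toy model, A -> B with flux c[0] * A (hand-written, not generated)
def rate_of_species_change(y, t, c):
    flux = c[0] * y[0]
    return jnp.array([-flux, flux])


y0 = jnp.array([1.0, 0.0])
ts = jnp.linspace(0.0, 5.0, 51)

# Target pattern: the [A] trajectory produced by an (assumed) true rate k = 0.8
target = odeint(rate_of_species_change, y0, ts, jnp.array([0.8]))[:, 0]


def loss_fn(c):
    """Mean squared error between simulated [A](t) and the target pattern."""
    ys = odeint(rate_of_species_change, y0, ts, c)
    return jnp.mean((ys[:, 0] - target) ** 2)


# odeint supports reverse-mode autodiff, so the loss can be differentiated in c
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

c = jnp.array([0.5])  # initial guess for the kinetic parameter
learning_rate = 1.0   # plain gradient descent step size (assumed)
losses = []
for step in range(100):
    loss, grad_c = loss_and_grad(c)
    losses.append(float(loss))
    c = c - learning_rate * grad_c  # manual update in place of an Optax optimizer
```

Replacing the manual update with an Optax optimizer only changes the last line of the loop; the differentiable simulation stays the same.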
All contributions are welcome!
SBMLtoODEjax is in its early stages and any sort of contribution will be highly appreciated.
Suggested contributions
There are several use cases that are not handled by the current codebase, including:
- Events: SBML files with events (discrete occurrences that can trigger discontinuous changes in the model) are not handled.
- Math functions: we handle a large portion, but not all, of the functions that may be used in SBML files (see mathFuncs in sbmltoodejax.modulegeneration.GenerateModel).
- Custom solvers: to integrate the model's equations, we use JAX's experimental odeint solver, but do not yet allow for other solvers.
- NaN/negative values: numerical simulation sometimes leads to NaN values (or negative values for the species amounts), which could be due either to wrong parsing or to solver issues.
This means that a large portion of SBML files cannot yet be simulated. For instance, as detailed in the image below, out of the 1048 curated models that can be loaded from the BioModels website, only 232 can be successfully simulated (given the default initial conditions) in SBMLtoODEjax:
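To make the custom-solver and NaN/negative-value items above more concrete, here is a minimal sketch of what swapping in another integrator could look like: a fixed-step classical Runge-Kutta (RK4) scheme driven by lax.scan, together with a check for the NaN/negative-value failure mode. This is not a feature of SBMLtoODEjax; the toy A → B model and the names rk4_step, simulate_rk4 and is_valid_trajectory are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical toy model, A -> B with flux c[0] * A (hand-written, not generated)
def rate_of_species_change(y, t, c):
    flux = c[0] * y[0]
    return jnp.array([-flux, flux])


def rk4_step(func, y, t, dt, c):
    """One classical 4th-order Runge-Kutta step of dy/dt = func(y, t, c)."""
    k1 = func(y, t, c)
    k2 = func(y + 0.5 * dt * k1, t + 0.5 * dt, c)
    k3 = func(y + 0.5 * dt * k2, t + 0.5 * dt, c)
    k4 = func(y + dt * k3, t + dt, c)
    return y + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


@jax.jit
def simulate_rk4(y0, c):
    """Integrate 100 fixed steps of size 0.1 and return the (100, 2) trajectory."""
    dt = 0.1

    def f(carry, _):
        y, t = carry
        y_next = rk4_step(rate_of_species_change, y, t, dt, c)
        return (y_next, t + dt), y_next

    _, ys = lax.scan(f, (y0, jnp.asarray(0.0, jnp.float32)), None, length=100)
    return ys


def is_valid_trajectory(ys):
    """False if any value is non-finite (NaN/inf) or any amount is negative."""
    return bool(jnp.all(jnp.isfinite(ys)) & jnp.all(ys >= 0.0))


ys = simulate_rk4(jnp.array([1.0, 0.0]), jnp.array([0.5]))
```

A fixed step avoids the adaptive step-size control of odeint, which makes the cost predictable but leaves accuracy entirely to the choice of dt; stiff models would likely need much smaller steps or an implicit method.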
👉 Please consider contributing and check our Contribution Guidelines to learn how to do so.
License
The SBMLtoODEjax project is licensed under the MIT license.
Acknowledgements
SBMLtoODEjax builds on:
- SBMLtoODEpy's parsing and conversion of SBML files, by Steve M. Ruggiero and Ashlee N. Ford
- JAX's composable transformations, by the Google team
- Equinox's module abstraction, by Patrick Kidger
- BasiCO's access to the BioModels REST API, by the COPASI team
Our documentation was also inspired by the GPJax documentation, by Thomas Pinder and team.
Citing SBMLtoODEjax
If you use SBMLtoODEjax in your research, please cite the paper:
@inproceedings{etcheverry2023sbmltoodejax,
title={SBMLtoODEjax: Efficient Simulation and Optimization of Biological Network Models in JAX},
author={Mayalen Etcheverry and Michael Levin and Clement Moulin-Frier and Pierre-Yves Oudeyer},
booktitle={NeurIPS 2023 AI for Science Workshop},
year={2023},
url={https://openreview.net/forum?id=exP6UntwqJ}
}
File details
Details for the file sbmltoodejax-0.4.2.tar.gz.
File metadata
- Download URL: sbmltoodejax-0.4.2.tar.gz
- Upload date:
- Size: 20.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | af982fdae54808231d4d238b60da6fc5173fe395b3dc373ee0bac7c0d0100685
MD5 | 8774d93a6ef7a56e42867c532a0d9237
BLAKE2b-256 | 4bc88eeb90011592916fc10211e49ead71193465e6afcfd2cadcbb9dd12ab7ff
File details
Details for the file sbmltoodejax-0.4.2-py3-none-any.whl.
File metadata
- Download URL: sbmltoodejax-0.4.2-py3-none-any.whl
- Upload date:
- Size: 17.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | c055fa0fa8ce1d833e20a12b0cbc8df2d5b99639546b99997cb1efbe1500cd22
MD5 | 7fcb12e76629feb2c4bf5f2c9d538903
BLAKE2b-256 | 2691c0787c41f2bd4d8dba132d112d2a60af1bcae19687c9cbd17339a6f130f6