
galilei


galilei is a Python package that makes emulating numerical functions easier and more composable. It supports multiple backends such as PyTorch-based neural networks, GPy-based Gaussian process regression, etc. As of now, it defaults to a jax+flax+optax backend, which supports automatic differentiation of the emulated function and composes easily with the rest of the jax-based ecosystem.

The motivation for emulating a function is that computing it can be time-consuming, so one may need a fast approximation that is better than basic interpolation techniques. An emulated function, on the other hand, can run blazingly fast on an ordinary GPU, achieving speed-ups of many orders of magnitude. The idea of emulating functions is not new: in the field of cosmology we have powerful tools such as cosmopower and derived works such as axionEmu, whose ideas inspired this work. My aim here differs from the previous approaches, as I intend to build a more generic and easily composable functional emulator that can take any parametrized numerical function as input and return a function with the exact same signature, usable as a drop-in replacement for the original function in an existing code base without any modifications.

Features

  • Support for multiple backends: torch, sklearn, gpy (for Gaussian process regression), jax.
  • Flexible: able to emulate generic numerical functions.
  • Automatically differentiable (with selected backends): the emulated function is automatically differentiable and composes easily with jax-based tools.
  • Easy to use: just add the @emulate decorator and use the emulated function as a drop-in replacement for your existing function without additional modification.
  • Arbitrary transformations of the function output before training through the use of a Preconditioner (see the conceptual sketch below).
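
To illustrate the last point: conceptually, a preconditioner is just a pair of transforms, one applied to the function output before the emulator is trained and its inverse applied to the emulator's predictions. The sketch below shows the idea only; it is not the actual galilei Preconditioner interface.

import numpy as np

# Conceptual sketch only: the actual galilei Preconditioner API may differ.
# A preconditioner is a pair of maps (forward, backward) with backward(forward(y)) == y.
forward = np.arcsinh    # compress the dynamic range before training
backward = np.sinh      # undo the transform on the emulator's predictions

y = 1e3 * np.sin(np.linspace(0, 10, 100))    # raw function output
y_train = forward(y)                         # what the emulator is trained on
y_back = backward(y_train)                   # what the user sees after prediction
assert np.allclose(y, y_back)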

Installation

pip install galilei

Basic usage

Suppose that we have an expensive function that we want to emulate

import numpy as np

def myfun(a, b):
    # assume this is a slow function
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

If you want to emulate this function, simply add the @emulate decorator and supply the parameter values at which you want to evaluate the function; these are used to build up the training data set.

from galilei import emulate

@emulate(samples={
    'a': np.random.rand(1000),
    'b': np.random.rand(1000)
})
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

Here we are just making 1000 pairs of random numbers drawn uniformly from 0 to 1 to train our function. When these lines are executed, the emulator starts training, and once it is done, the original myfun function is automatically replaced with the emulated version, which should behave in the same way, except much faster!

Training emulator...
100%|██████████| 500/500 [00:09<00:00, 50.50it/s, loss=0.023]
Ave Test loss: 0.025
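
Since the decorated myfun now behaves like the original, one quick way to check the quality of the fit is to compare it against the exact implementation at a held-out point. This is only a sketch; the accuracy you need depends on your application.

import numpy as np

def myfun_exact(a, b):
    # the original (slow) implementation, kept around for validation
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

pred = myfun(0.3, 0.7)          # emulated, fast
truth = myfun_exact(0.3, 0.7)   # exact, slow
print(np.max(np.abs(np.asarray(pred) - truth)))   # maximum absolute error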

[Figure: Comparison]

With the default jax backend, the emulated function is automatically jax compatible, which means one can easily compose it with jax machinery, as in the example below where the emulated function is compiled with jit and then vectorized over its first argument with vmap.

from jax import jit, vmap

vmap(jit(myfun), in_axes=(0, None))(np.linspace(0, 1, 10), 0.5)

Output:

Array([[-2.39813775e-02,  2.16133818e-02,  8.05920288e-02,
         1.66035295e-01,  2.01425016e-01,  2.42054626e-01,
         2.74079561e-01,  3.50277930e-01,  4.12616253e-01,
         4.33193207e-01,  4.82740909e-01,  5.66871405e-01,
         5.73131263e-01,  6.51429832e-01,  6.55564785e-01,
         ...

The emulated function is also automatically differentiable, regardless of the original implementation details. For example, we can easily evaluate its Jacobian (without finite differencing) with

from jax import jacfwd

jacfwd(myfun, argnums=(0,1))(0.2, 0.8)

Output:

(Array([ 0.05104188,  0.18436229,  0.08595917,  0.06582363,  0.17270228, ...],      dtype=float32),
 Array([-3.3511031e-01,  1.2647966e-01,  4.3209594e-02,  2.4325712e-01, ...],      dtype=float32))
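
Because the emulated function is just another jax-traceable function, it also composes with the rest of jax's transformations. For instance, one can take the gradient of a scalar reduction of its output; the snippet below is a sketch using jnp.mean as an arbitrary reduction.

from jax import grad
import jax.numpy as jnp

# gradient of the mean of the emulated output with respect to both arguments
g = grad(lambda a, b: jnp.mean(myfun(a, b)), argnums=(0, 1))(0.2, 0.8)
print(g)   # a tuple with one scalar per argument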

You can also easily save your trained model with the save option

@emulate(samples={
    'a': np.random.rand(100),
    'b': np.random.rand(100)
}, backend='sklearn', save="test.pkl")
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

and when you use it in production, simply load a pretrained model with

@emulate(backend='sklearn', load="test.pkl")
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

and your function will be replaced with a fast emulated version. For more detailed usage examples, see the example notebook (available on Colab).
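
In production, the loaded emulator is then called exactly like the original function, so, for example, scanning over a grid of parameter values no longer involves the expensive computation. A small sketch:

import numpy as np

# evaluate the (loaded) emulated function over a small grid of 'a' values
results = [myfun(a, 0.5) for a in np.linspace(0, 1, 20)]
print(np.shape(results))   # (20, 100): 20 parameter values, 100 output points each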

Roadmap

  • TODO add prebuilt preconditioners
  • TODO support downloading files from the web
  • TODO auto-infer backend
  • TODO chains of preconditioners

Credits

This package was created with the ppw tool. For more information, please visit the project page.

If this package is helpful in your work, please consider citing:

@article{yguan_2023,
    title={galilei: a generic function emulator},
    DOI={10.5281/zenodo.7651315},
    publisher={Zenodo},
    author={Yilun Guan},
    year={2023},
    month={Feb}}

Free software: MIT
