

galilei


galilei is a Python package that makes emulating numerical functions easier and more composable. It supports multiple backends such as PyTorch-based neural networks, GPy-based Gaussian process regression, etc. It currently defaults to a jax+flax+optax backend, which supports automatic differentiation of the emulated function and composes easily with the rest of the jax-based ecosystem.

The motivation for emulating a function is that evaluating it can be time consuming, so one may want a fast approximation that outperforms basic interpolation techniques. Machine learning methods such as neural networks offer a solution generic enough to emulate functions of almost any shape. In contrast to the original function, a neural-network-based emulator runs blazingly fast, and even more so on a GPU, often achieving speed-ups of many orders of magnitude.

The idea of emulating functions is not new. In the field of cosmology we have powerful tools such as cosmopower and derived works such as axionEmu, whose ideas inspired this work. My aim differs from these previous approaches: I intend to build a generic and easily composable function emulator that can take any parametrized numerical function as input and return a function with the exact same signature, usable as a drop-in replacement in an existing code base with no additional code changes. In addition, I focus on making the emulated function automatically differentiable regardless of its original implementation.

Features

  • Supports multiple backends: torch, sklearn, gpy (for Gaussian process regression), and jax.
  • Flexible: able to emulate generic numerical functions.
  • Automatically differentiable (with selected backends): the emulated function is automatically differentiable and composes easily with jax-based tools.
  • Easy to use: just add the @emulate decorator and use the emulated function as a drop-in replacement for the original in your code base, with no further modification.
  • Allows arbitrary transformations of the function output before training through the use of a Preconditioner (see the sketch after this list).
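
The Preconditioner interface is not documented in this README, so the following is only a rough sketch of the concept under my own naming assumptions (the class LogPreconditioner and the forward/backward methods are illustrative, not galilei's actual API): a preconditioner pairs a transform applied to the training outputs with an inverse transform applied to the emulator's predictions.

import numpy as np

# Hypothetical sketch of the preconditioner concept; the names here
# are illustrative assumptions, not galilei's documented API.
class LogPreconditioner:
    def forward(self, y):
        # applied to training outputs: compress the dynamic range
        return np.log(y)

    def backward(self, y):
        # applied to emulator predictions: undo the transform
        return np.exp(y)

A transform like this can make strictly positive, strongly peaked outputs much easier for a network to learn.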

Installation

pip install galilei

Basic usage

Suppose we have an expensive function that we want to emulate:

import numpy as np

def myfun(a, b):
    # assume this is a slow function
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

To emulate this function, simply add the @emulate decorator and supply the parameter values at which the function should be evaluated to build up the training data set.

from galilei import emulate

@emulate(samples={
    'a': np.random.rand(1000),
    'b': np.random.rand(1000)
})
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

Here we simply draw 1000 pairs of random numbers between 0 and 1 to train on. When these lines execute, the emulator starts training; once it finishes, the original myfun is automatically replaced with the emulated version, which behaves the same way, only much faster!

Training emulator...
100%|██████████| 500/500 [00:09<00:00, 50.50it/s, loss=0.023]
Ave Test loss: 0.025
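
To see that the drop-in replacement really behaves like the original, a quick sanity check (illustrative, not from the galilei docs) is to keep an undecorated copy of the implementation and compare the two at a held-out point; myfun_exact below is such a copy, not part of galilei:

import numpy as np

def myfun_exact(a, b):
    # undecorated copy of the original implementation
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

a, b = 0.3, 0.7
err = np.max(np.abs(np.asarray(myfun(a, b)) - myfun_exact(a, b)))
print(f"max abs error at ({a}, {b}): {err:.4f}")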

[Figure: comparison of the original and emulated function]

With the default jax backend, the emulated function is automatically jax compatible, which means it composes easily with jax machinery. In the example below, I compile the emulated function with jit and then vectorize it over its first argument with vmap.

from jax import jit, vmap

vmap(jit(myfun), in_axes=(0, None))(np.linspace(0, 1, 10), 0.5)

Output:

Array([[-2.39813775e-02,  2.16133818e-02,  8.05920288e-02,
         1.66035295e-01,  2.01425016e-01,  2.42054626e-01,
         2.74079561e-01,  3.50277930e-01,  4.12616253e-01,
         4.33193207e-01,  4.82740909e-01,  5.66871405e-01,
         5.73131263e-01,  6.51429832e-01,  6.55564785e-01,
         ...

The emulated function is also automatically differentiable, regardless of the original implementation details. For example, we can easily evaluate its Jacobian (without finite differencing) with

from jax import jacfwd

jacfwd(myfun, argnums=(0,1))(0.2, 0.8)

Output:

(Array([ 0.05104188,  0.18436229,  0.08595917,  0.06582363,  0.17270228, ...],      dtype=float32),
 Array([-3.3511031e-01,  1.2647966e-01,  4.3209594e-02,  2.4325712e-01, ...],      dtype=float32))
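
Because gradients flow through the emulator, it can also sit inside a larger differentiable computation. As an illustrative sketch (this optimization setup is mine, not from the galilei docs), plain gradient descent can recover (a, b) from a target curve using only the emulated function:

import jax
import jax.numpy as jnp

target = myfun(0.2, 0.8)  # pretend this curve is observed data

def loss(params):
    # mean squared error between the emulator output and the target
    a, b = params
    return jnp.mean((myfun(a, b) - target) ** 2)

params = jnp.array([0.5, 0.5])
for _ in range(200):
    params = params - 0.1 * jax.grad(loss)(params)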

You can also easily save your trained model with the save option

@emulate(samples={
    'a': np.random.rand(100),
    'b': np.random.rand(100)
}, backend='sklearn', save="test.pkl")
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

and when you use it in production, simply load a pretrained model with

@emulate(backend='sklearn', load="test.pkl")
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

and your function will be replaced with a fast emulated version.

[Figure: comparison of the original and emulated function]

It is also possible to sample training points from a Latin hypercube using the build_samples function. For example, here I build a 100-sample Latin hypercube over the given ranges of a and b:

from galilei.sampling import build_samples
@emulate(
    samples=build_samples({"a": [0, 2], "b": [0, 2]}, 100),
    backend='sklearn'
)
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

Sometimes one might want to collect training data without immediately training the emulator. This can be done with

from galilei.experimental import collect

@collect(
    samples=build_samples({"a": [0, 2], "b": [0, 2]}, 100),
    save="collection.pkl",
    mpi=True
)
def myfun(a, b):
    x = np.linspace(0, 10, 100)
    return np.sin(a*x) + np.sin(b*x)

which will save a precomputed collection to collection.pkl for future loading. Note that the mpi option requires a working mpi4py, which the user needs to install separately. The collection can then be loaded to train an emulator:

@emulate(
    collection="collection.pkl",
    backend='sklearn'
)
def myfunc(a, b):
    raise Exception()

myfunc(1, 1)

Since the function is never evaluated in this case, the implementation of myfunc makes no difference (otherwise the call would have raised an exception).

For more usage examples, see the example notebook, which can be opened in Colab.

Roadmap

  • TODO: add prebuilt preconditioners
  • TODO: support downloading files from the web
  • TODO: auto-infer the backend
  • TODO: support chains of preconditioners

Credits

This package was created with the ppw tool. For more information, please visit the project page.

If this package is helpful in your work, please consider citing:

@article{yguan_2023,
    title={galilei: a generic function emulator},
    DOI={10.5281/zenodo.7651315},
    publisher={Zenodo},
    author={Yilun Guan},
    year={2023},
    month={Feb}}

Free software: MIT
