Generative Models using JAX
Project description
generax
generax provides implementations of flow-based generative models. The library is built on top of Equinox, which removes the need to keep track of model parameters manually.
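```python
import jax.random as random
import equinox as eqx
# NeuralSpline is a flow model provided by the generax package

key = random.PRNGKey(0)  # JAX random key
x = ...  # some data, shape (batch_size, *input_shape)

# Use separate keys for construction, initialization and sampling
key, model_key, init_key, sample_key = random.split(key, 4)

# Create a flow model
model = NeuralSpline(input_shape=x.shape[1:],
                     n_flow_layers=3,
                     n_blocks=4,
                     hidden_size=32,
                     working_size=16,
                     n_spline_knots=8,
                     key=model_key)

# Data dependent initialization
model = model.data_dependent_init(x, key=init_key)

# Take multiple samples using vmap
keys = random.split(sample_key, 1000)
samples = eqx.filter_vmap(model.sample)(keys)

# Compute the log probability of data
log_prob = eqx.filter_vmap(model.log_prob)(x)
```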
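Flow models map data through an invertible transform to a simple base distribution, so `log_prob` follows from the change-of-variables formula. The sketch below illustrates that idea with a single elementwise affine layer in plain JAX. It is a minimal illustration under that assumption, not generax's implementation: `init_params`, `data_dependent_init`, `log_prob` and `sample` here are hypothetical functions that only mirror the method names used above, and the initialization shown is one common (ActNorm-style) choice that standardizes the batch.

```python
import jax
import jax.numpy as jnp
import jax.random as random

# Hypothetical elementwise affine flow z = (x - shift) * exp(-log_scale)
# with a standard normal base distribution (not generax's code).

def init_params(input_dim):
    return {"shift": jnp.zeros(input_dim), "log_scale": jnp.zeros(input_dim)}

def data_dependent_init(params, x):
    # Pick shift/scale so the transformed batch has zero mean and unit variance
    mean = x.mean(axis=0)
    std = x.std(axis=0) + 1e-6
    return {"shift": mean, "log_scale": jnp.log(std)}

def log_prob(params, x):
    # Single point x of shape (input_dim,). Change of variables:
    # log p(x) = log N(z; 0, I) + log|det dz/dx|
    z = (x - params["shift"]) * jnp.exp(-params["log_scale"])
    base = -0.5 * jnp.sum(z ** 2) - 0.5 * z.size * jnp.log(2 * jnp.pi)
    log_det = -jnp.sum(params["log_scale"])
    return base + log_det

def sample(params, key):
    # Draw from the base distribution and apply the inverse transform
    z = random.normal(key, params["shift"].shape)
    return params["shift"] + z * jnp.exp(params["log_scale"])

key = random.PRNGKey(0)
data_key, sample_key = random.split(key)
x = 3.0 + 2.0 * random.normal(data_key, (500, 2))  # toy data

params = data_dependent_init(init_params(2), x)
samples = jax.vmap(lambda k: sample(params, k))(random.split(sample_key, 1000))
log_probs = jax.vmap(lambda xi: log_prob(params, xi))(x)
```

Real flows stack many invertible layers and sum their log-determinant terms; the vmap pattern over keys and data points is the same as in the example above.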
generax also supports probability paths (time-dependent probability distributions), which can be used to train continuous normalizing flows with flow matching. See the flow matching and multi-sample flow matching examples for more details.
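Flow matching trains a velocity field v_θ(t, x) whose ODE carries base samples to data. The sketch below shows the basic conditional flow matching objective with a straight-line path x_t = (1 − t) x_0 + t x_1, whose target velocity is x_1 − x_0, followed by Euler integration to draw a sample. It is written in plain JAX with a hypothetical two-layer MLP (`init_mlp`, `velocity`), does not use generax's probability-path classes, and pairs noise and data independently rather than with the multi-sample coupling.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def init_mlp(key, dim, hidden=64):
    # Hypothetical velocity network: input is (x, t), output has shape (dim,)
    k1, k2 = random.split(key)
    return {
        "W1": random.normal(k1, (dim + 1, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "W2": random.normal(k2, (hidden, dim)) * 0.1,
        "b2": jnp.zeros(dim),
    }

def velocity(params, t, x):
    # v_theta(t, x) for a single point x of shape (dim,) and scalar t
    h = jnp.tanh(jnp.concatenate([x, jnp.array([t])]) @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def flow_matching_loss(params, key, x1):
    # x1: batch of data, shape (batch, dim)
    k_noise, k_t = random.split(key)
    x0 = random.normal(k_noise, x1.shape)          # base samples
    t = random.uniform(k_t, (x1.shape[0],))        # one time per pair
    xt = (1 - t[:, None]) * x0 + t[:, None] * x1   # point on the straight-line path
    target = x1 - x0                               # d x_t / dt along that path
    pred = jax.vmap(velocity, in_axes=(None, 0, 0))(params, t, xt)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))

def sample(params, key, dim, n_steps=100):
    # Integrate dx/dt = v(t, x) from t=0 to t=1 with forward Euler
    x = random.normal(key, (dim,))
    dt = 1.0 / n_steps
    def step(x, i):
        return x + dt * velocity(params, i * dt, x), None
    x, _ = jax.lax.scan(step, x, jnp.arange(n_steps))
    return x

key = random.PRNGKey(0)
p_key, d_key, l_key, s_key = random.split(key, 4)
params = init_mlp(p_key, dim=2)
x1 = random.normal(d_key, (128, 2)) + 4.0  # toy data
loss, grads = jax.value_and_grad(flow_matching_loss)(params, l_key, x1)
x_gen = sample(params, s_key, dim=2)
```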
Installation
generax is available on PyPI:
```
pip install generax
```
Training
generax provides a simple interface for training these models:
```python
# Trainer comes from generax. my_objective, tester, optimizer (an Optax
# optimizer) and train_ds are user-defined; see the examples folder.
trainer = Trainer(checkpoint_path='tmp/model_path')
model = trainer.train(model=model,               # generax model
                      objective=my_objective,    # Objective function
                      evaluate_model=tester,     # Testing function
                      optimizer=optimizer,       # Optax optimizer
                      num_steps=10000,           # Number of training steps
                      data_iterator=train_ds,    # Training data iterator
                      double_batch=1000,         # Number of batches trained per scan loop
                      checkpoint_every=1000,     # Checkpoint interval
                      test_every=1000,           # Test interval
                      retrain=True)              # Retrain from checkpoint
```
See the examples folder for more details.
Download files
File details
Details for the file generax-0.1.2.tar.gz.
File metadata
- Download URL: generax-0.1.2.tar.gz
- Upload date:
- Size: 68.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | c66a108811f5dab40c0052eb3e071e6a023d19e6e384bccc8fb9810113fcefca |
| MD5 | f65ca6dcbdd372fdde4ff78dc2223ec2 |
| BLAKE2b-256 | 5035f28703fd9a5ca9b38a7a5adc238b47d86d80afaf9a6e66e3b2a2c50d1e03 |
File details
Details for the file generax-0.1.2-py3-none-any.whl.
File metadata
- Download URL: generax-0.1.2-py3-none-any.whl
- Upload date:
- Size: 123.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | c5b123455276d412909bd65f2c372f4d4f98aa94ac294b9ae09017c7a27ae936 |
| MD5 | c07bf2558ed3900dbf881468de8b8f8c |
| BLAKE2b-256 | 7dde189082c29455a3f7d45e15569fedae9f4195621a63c12212b502c7bf6a2a |