Generative Models using JAX
generax
generax provides implementations of several kinds of generative models. The library is built on top of Equinox, which represents models as PyTrees, so there is no need to keep track of model parameters separately. For example, the following snippet creates a neural spline flow, samples from it, and computes the log probability of data.
from jax import random

key = random.PRNGKey(0)  # JAX random key
# Use a separate key for each random operation instead of reusing one
init_key, data_key, sample_key = random.split(key, 3)
x = ...  # some data, with the batch dimension first

# Create a flow model
model = NeuralSpline(input_shape=x.shape[1:],
                     n_flow_layers=3,
                     n_blocks=4,
                     hidden_size=32,
                     working_size=16,
                     n_spline_knots=8,
                     key=init_key)

# Data dependent initialization
model = model.data_dependent_init(x, key=data_key)

# Sample from the model
samples = model.sample(sample_key, n_samples=1000)

# Compute the log probability of data
log_prob = model.log_prob(x)
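To make the roles of `sample` and `log_prob` concrete, here is a minimal sketch of a one-layer affine flow written directly in JAX. It is not generax code and does not use `NeuralSpline`; the names `affine_forward`, `affine_log_prob`, `affine_sample`, and the `params` dictionary are hypothetical and exist only for illustration. The sketch assumes a standard normal base distribution and applies the change-of-variables formula.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm


def affine_forward(params, z):
    """Map base noise z to data space: x = z * exp(log_scale) + shift."""
    return z * jnp.exp(params["log_scale"]) + params["shift"]


def affine_log_prob(params, x):
    """Log density of x under the flow via the change-of-variables formula."""
    # Invert the transform to recover the base sample.
    z = (x - params["shift"]) * jnp.exp(-params["log_scale"])
    # log p(x) = log N(z; 0, I) - sum(log_scale), summed over feature dimensions.
    base_log_prob = norm.logpdf(z).sum(axis=-1)
    return base_log_prob - params["log_scale"].sum()


def affine_sample(params, key, n_samples):
    """Draw samples by pushing standard normal noise through the transform."""
    dim = params["shift"].shape[0]
    z = random.normal(key, (n_samples, dim))
    return affine_forward(params, z)


# Hypothetical parameters for a two-dimensional toy flow.
params = {"shift": jnp.array([1.0, -2.0]), "log_scale": jnp.array([0.5, 0.0])}
key = random.PRNGKey(0)
samples = affine_sample(params, key, n_samples=1000)  # shape (1000, 2)
log_prob = affine_log_prob(params, samples)           # shape (1000,)
```

A neural spline flow replaces the affine map with a stack of learned spline transforms, but the two operations have the same meaning: sampling pushes noise forward, and the log probability inverts the map and corrects for the change in volume.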
Installation
generax is available on PyPI and can be installed with pip:
pip install generax
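After installing, one way to confirm which version is present is to query the package metadata from Python. This is a small sketch that uses only the standard library; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


# Prints the installed generax version, or None if it is not installed.
print(installed_version("generax"))
```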
Roadmap
Implemented
- Normalizing flows
- Continuous normalizing flows
- Diffusion models
These models can be trained using a variety of methods, including:
- Maximum likelihood
- Score matching
- Flow matching
- Variational inference
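As an illustration of one of the objectives listed above, the following is a minimal sketch of a conditional flow matching loss with a straight-line path between noise and data. It is written in plain JAX and is not generax's implementation; the toy `velocity` model, its parameters `W`, `a`, `b`, and the function `flow_matching_loss` are all hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp
from jax import random


def velocity(params, t, x):
    """Hypothetical linear vector field v(t, x) = x @ W + t * a + b."""
    return x @ params["W"] + t[:, None] * params["a"] + params["b"]


def flow_matching_loss(params, key, x1):
    """Conditional flow matching loss with a straight-line probability path."""
    t_key, noise_key = random.split(key)
    t = random.uniform(t_key, (x1.shape[0],))       # times in [0, 1)
    x0 = random.normal(noise_key, x1.shape)         # base (noise) samples
    xt = (1.0 - t[:, None]) * x0 + t[:, None] * x1  # point on the path at time t
    target = x1 - x0                                # velocity of that path
    residual = velocity(params, t, xt) - target
    return jnp.mean(jnp.sum(residual**2, axis=-1))


key = random.PRNGKey(0)
data_key, loss_key = random.split(key)
x1 = 2.0 + 0.5 * random.normal(data_key, (256, 2))  # toy data
dim = x1.shape[-1]
params = {"W": jnp.zeros((dim, dim)), "a": jnp.zeros(dim), "b": jnp.zeros(dim)}

# The loss is differentiable with respect to the parameter PyTree.
loss, grads = jax.value_and_grad(flow_matching_loss)(params, loss_key, x1)
```

The other objectives follow the same pattern: each is a scalar function of the model and a batch of data that can be differentiated and handed to an optimizer.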
Training
generax provides a Trainer interface for training these models:
trainer = Trainer(checkpoint_path='tmp/RealNVP')
model = trainer.train(model=model,               # generax model
                      objective=max_likelihood,  # Objective function
                      evaluate_model=tester,     # Testing function
                      optimizer=optimizer,       # Optax optimizer
                      num_steps=10000,           # Number of training steps
                      data_iterator=train_ds,    # Training data iterator
                      double_batch=1000,         # Number of batches to train in one scan loop
                      checkpoint_every=1000,     # Checkpoint interval
                      test_every=1000,           # Test interval
                      retrain=True)              # Retrain from checkpoint
See the tutorial for an example.
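The `double_batch` comment refers to training several batches inside a scan loop. The sketch below shows the general idea with `jax.lax.scan` on a toy problem. It is not the `Trainer` implementation: it assumes a diagonal Gaussian model, a maximum-likelihood objective, and plain gradient descent in place of an Optax optimizer, and the names `negative_log_likelihood`, `sgd_step`, and `train_on_stacked_batches` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm


def negative_log_likelihood(params, batch):
    """Maximum-likelihood objective for a toy diagonal Gaussian model."""
    log_prob = norm.logpdf(batch, loc=params["mean"], scale=jnp.exp(params["log_std"]))
    return -jnp.mean(jnp.sum(log_prob, axis=-1))


def sgd_step(params, batch, learning_rate=0.05):
    """One plain gradient-descent update; returns new params and the loss."""
    loss, grads = jax.value_and_grad(negative_log_likelihood)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params, loss


@jax.jit
def train_on_stacked_batches(params, stacked_batches):
    """Run one update per leading-axis slice inside a single compiled scan."""
    return jax.lax.scan(sgd_step, params, stacked_batches)


key = random.PRNGKey(0)
# 200 batches of 64 two-dimensional points, stacked along the leading axis.
stacked_batches = 3.0 + 0.5 * random.normal(key, (200, 64, 2))
params = {"mean": jnp.zeros(2), "log_std": jnp.zeros(2)}
params, losses = train_on_stacked_batches(params, stacked_batches)  # losses: (200,)
```

Running many updates inside one compiled scan avoids returning to Python between steps, which is the usual motivation for grouping batches this way.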
File details
Details for the file generax-0.0.3.tar.gz.
File metadata
- Download URL: generax-0.0.3.tar.gz
- Upload date:
- Size: 40.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 79c59e748ed80f307ec13a0926a13c89716ac11f1e7c23541fffe97fef714fac |
| MD5 | 270e33f6fd1f3cc12605e815d3a66510 |
| BLAKE2b-256 | 3fb0dd0eb0f406b41882638984c528bde88f2d5e7c0bed7b5ebde658e2d3a8b4 |
File details
Details for the file generax-0.0.3-py3-none-any.whl.
File metadata
- Download URL: generax-0.0.3-py3-none-any.whl
- Upload date:
- Size: 62.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 50cc8a2d82b2a7d6fa0b47a91dca98d8693d7943d3db6f8c0f3a7fb2e15e4798 |
| MD5 | 2cd3b3133d4c7993ff05719e585212e0 |
| BLAKE2b-256 | 9816c60af9d281f68ed3954a77aff995995bfe55b6db756f7121e63744959e54 |
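The published digests can be used to check a downloaded file. The following sketch uses only the standard library; the helper names `sha256_of_file` and `matches_published_hash` are hypothetical, and the example call assumes the source archive was saved in the current directory.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """Compare a file's digest to a published one, ignoring case and whitespace."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# SHA256 digest listed above for generax-0.0.3.tar.gz.
EXPECTED_SDIST_SHA256 = "79c59e748ed80f307ec13a0926a13c89716ac11f1e7c23541fffe97fef714fac"

# Example call (assumes the archive is in the current directory):
# matches_published_hash("generax-0.0.3.tar.gz", EXPECTED_SDIST_SHA256)
```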