
Generative models using JAX

Project description

generax

generax provides implementations of flow-based generative models. The library is built on top of Equinox, so models are PyTrees and there is no need to keep track of model parameters separately.
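Some background on the flows themselves. This is standard normalizing-flow theory, not a description of generax's internals. A flow is an invertible, differentiable map f from data x to a latent z = f(x) with a simple base density p_Z, usually a standard normal. The change-of-variables formula gives the exact log-likelihood, and for a stack of K layers f = f_K ∘ ... ∘ f_1 the log-determinant terms add up:

```latex
\log p_X(x) = \log p_Z\big(f(x)\big) + \log\left|\det \frac{\partial f(x)}{\partial x}\right|,
\qquad
\log\left|\det \frac{\partial f(x)}{\partial x}\right|
  = \sum_{k=1}^{K} \log\left|\det \frac{\partial f_k(h_{k-1})}{\partial h_{k-1}}\right|,
\quad h_0 = x,\; h_k = f_k(h_{k-1}).
```

Sampling runs the map backwards: draw z from p_Z and return x = f^{-1}(z).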

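import jax.random as random
import equinox as eqx
# NeuralSpline is provided by generax; see the package for its exact import path.

key = random.PRNGKey(0)  # JAX random key
init_key, ddi_key, sample_key = random.split(key, 3)  # use a fresh key for each step
x = ...  # some data, shape (batch_size, *input_shape)

# Create a flow model
model = NeuralSpline(input_shape=x.shape[1:],
                     n_flow_layers=3,
                     n_blocks=4,
                     hidden_size=32,
                     working_size=16,
                     n_spline_knots=8,
                     key=init_key)

# Data-dependent initialization
model = model.data_dependent_init(x, key=ddi_key)

# Draw 1000 samples by vmapping over a batch of keys
keys = random.split(sample_key, 1000)
samples = eqx.filter_vmap(model.sample)(keys)

# Compute the log probability of each data point
log_prob = eqx.filter_vmap(model.log_prob)(x)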
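To make the three steps above concrete, here is a minimal sketch of a single elementwise affine flow layer written in plain JAX rather than with generax. The initialization rule is an assumption. It is the ActNorm-style rule from Glow, which sets the shift and scale so the initial latents have zero mean and unit variance per feature, and it may differ from what generax's data_dependent_init does. The names AffineFlow, init_from_data, sample and log_prob below are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
from typing import NamedTuple


class AffineFlow(NamedTuple):
    """Hypothetical elementwise affine flow: z = (x - shift) * exp(-log_scale)."""
    shift: jnp.ndarray      # per-feature shift, shape (dim,)
    log_scale: jnp.ndarray  # per-feature log scale, shape (dim,)


def init_from_data(x: jnp.ndarray) -> AffineFlow:
    # ActNorm-style data-dependent init (an assumption, not generax's code):
    # the latents of this batch get zero mean and unit variance per feature.
    mean = x.mean(axis=0)
    std = x.std(axis=0) + 1e-6  # epsilon avoids log(0) for constant features
    return AffineFlow(shift=mean, log_scale=jnp.log(std))


def log_prob(flow: AffineFlow, x: jnp.ndarray) -> jnp.ndarray:
    # Change of variables for one unbatched point x of shape (dim,).
    z = (x - flow.shift) * jnp.exp(-flow.log_scale)
    log_pz = norm.logpdf(z).sum()    # standard normal base density
    log_det = -flow.log_scale.sum()  # log|det dz/dx| of the diagonal map
    return log_pz + log_det


def sample(flow: AffineFlow, key: jnp.ndarray) -> jnp.ndarray:
    # Draw z from the base density and push it through the inverse map.
    z = jax.random.normal(key, flow.shift.shape)
    return z * jnp.exp(flow.log_scale) + flow.shift


# Usage: mirror the steps above on toy 2-D data.
key = jax.random.PRNGKey(0)
data_key, sample_key = jax.random.split(key)
x = 3.0 + 2.0 * jax.random.normal(data_key, (500, 2))

flow = init_from_data(x)
sample_keys = jax.random.split(sample_key, 1000)
samples = jax.vmap(lambda k: sample(flow, k))(sample_keys)  # shape (1000, 2)
log_probs = jax.vmap(lambda xi: log_prob(flow, xi))(x)      # shape (500,)
```

As in the generax example, sampling is vectorized by mapping over a batch of keys, and the log density is vectorized by mapping over data points.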

There is also support for probability paths (time-dependent probability distributions), which can be used to train continuous normalizing flows with flow matching. See the examples on flow matching and multi-sample flow matching for more details.
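As background for the flow matching examples, here is a minimal sketch of the conditional flow matching objective in plain JAX. It does not use generax's probability-path classes. It assumes the simple linear path x_t = (1 - t) x_0 + t x_1 between a standard-normal sample x_0 (t = 0) and a data point x_1 (t = 1), whose target velocity is x_1 - x_0, and it regresses a vector field v(x_t, t) onto that target. The functions linear_path, flow_matching_loss and velocity_fn are hypothetical, and the affine velocity_fn stands in for a real neural network.

```python
import jax
import jax.numpy as jnp


def linear_path(x0, x1, t):
    # Hypothetical linear probability path: x_t = (1 - t) * x0 + t * x1.
    xt = (1.0 - t) * x0 + t * x1
    ut = x1 - x0  # d x_t / dt, the regression target for the vector field
    return xt, ut


def flow_matching_loss(velocity_fn, params, x1, key):
    """Conditional flow matching loss for a data batch x1 of shape (batch, dim).

    velocity_fn(params, x, t) maps one point x of shape (dim,) and a scalar
    time t to a velocity of shape (dim,).
    """
    k0, kt = jax.random.split(key)
    x0 = jax.random.normal(k0, x1.shape)          # base samples from N(0, I)
    t = jax.random.uniform(kt, (x1.shape[0], 1))  # one time per example
    xt, ut = linear_path(x0, x1, t)
    vt = jax.vmap(velocity_fn, in_axes=(None, 0, 0))(params, xt, t[:, 0])
    return jnp.mean(jnp.sum((vt - ut) ** 2, axis=-1))


def velocity_fn(params, x, t):
    # Toy affine vector field standing in for a neural network.
    return params["W"] @ x + params["b"] + t * params["c"]


# Usage: evaluate the loss and its gradient on a toy batch.
key = jax.random.PRNGKey(0)
data_key, loss_key = jax.random.split(key)
x1 = jax.random.normal(data_key, (128, 2)) + 5.0
params = {"W": jnp.zeros((2, 2)), "b": jnp.zeros(2), "c": jnp.zeros(2)}
loss, grads = jax.value_and_grad(
    lambda p: flow_matching_loss(velocity_fn, p, x1, loss_key)
)(params)
```

After training, samples would be generated by integrating the learned vector field from t = 0 to t = 1 with an ODE solver, starting from base samples; that step is not shown.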

Samples

[Figure: samples]

Installation

generax is available on pip:

pip install generax

Training

generax provides a simple Trainer interface for training these models:

import optax

# Trainer comes from generax; my_objective, tester and train_ds are placeholders
# to be defined for your own problem.
optimizer = optax.adam(learning_rate=1e-3)  # any Optax optimizer; 1e-3 is an example value

trainer = Trainer(checkpoint_path='tmp/model_path')

model = trainer.train(model=model,              # generax model
                      objective=my_objective,   # Objective function
                      evaluate_model=tester,    # Testing function
                      optimizer=optimizer,      # Optax optimizer
                      num_steps=10000,          # Number of training steps
                      data_iterator=train_ds,   # Training data iterator
                      double_batch=1000,        # Number of batches trained inside each scan loop
                      checkpoint_every=1000,    # Checkpoint interval
                      test_every=1000,          # Test interval
                      retrain=True)             # Retrain from checkpoint

See the examples folder for more details.
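The double_batch option trains many batches inside a scan loop. A minimal sketch of that idea in plain JAX follows. It is not generax's Trainer implementation: the least-squares objective, the plain SGD update, the learning rate and the helper names (objective, train_many_steps) are all hypothetical. K batches are stacked along a leading axis, and jax.lax.scan runs one optimizer step per batch inside a single jitted call, which avoids K separate Python-level dispatches.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical SGD step size


def objective(params, batch):
    # Hypothetical objective: mean squared error of a linear model y ~ x @ w + b.
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_many_steps(params, batches):
    """Run one SGD step per batch; every array in `batches` has leading axis K."""

    def step(params, batch):
        loss, grads = jax.value_and_grad(objective)(params, batch)
        params = jax.tree_util.tree_map(
            lambda p, g: p - LEARNING_RATE * g, params, grads
        )
        return params, loss  # carry the updated params, record this step's loss

    return jax.lax.scan(step, params, batches)


# Usage: 100 batches of 32 examples, all trained in one jitted call.
key = jax.random.PRNGKey(0)
kx, kn = jax.random.split(key)
K, B, D = 100, 32, 3
true_w = jnp.array([1.0, -2.0, 0.5])
x = jax.random.normal(kx, (K, B, D))
y = x @ true_w + 0.3 + 0.01 * jax.random.normal(kn, (K, B))

params = {"w": jnp.zeros(D), "b": jnp.array(0.0)}
params, losses = train_many_steps(params, (x, y))  # losses has shape (K,)
```

The trade-off is that all K batches must be materialized as stacked arrays up front, so larger values use more device memory per call.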



Download files


Source Distribution

generax-0.1.2.tar.gz (68.4 kB)

Built Distribution

generax-0.1.2-py3-none-any.whl (123.8 kB)

File details

Details for the file generax-0.1.2.tar.gz.

File metadata

  • Download URL: generax-0.1.2.tar.gz
  • Size: 68.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for generax-0.1.2.tar.gz
Algorithm    Hash digest
SHA256       c66a108811f5dab40c0052eb3e071e6a023d19e6e384bccc8fb9810113fcefca
MD5          f65ca6dcbdd372fdde4ff78dc2223ec2
BLAKE2b-256  5035f28703fd9a5ca9b38a7a5adc238b47d86d80afaf9a6e66e3b2a2c50d1e03


File details

Details for the file generax-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: generax-0.1.2-py3-none-any.whl
  • Size: 123.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for generax-0.1.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       c5b123455276d412909bd65f2c372f4d4f98aa94ac294b9ae09017c7a27ae936
MD5          c07bf2558ed3900dbf881468de8b8f8c
BLAKE2b-256  7dde189082c29455a3f7d45e15569fedae9f4195621a63c12212b502c7bf6a2a

