Generative models using JAX

Generax

Generax provides implementations of several kinds of generative models. The library is built on top of Equinox, where a model is a PyTree that holds its own parameters, so there is no need to track model parameters separately. For example, the following code snippet shows how to create a neural spline flow, sample from it, and compute the log probability of data.

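from jax import random
# NeuralSpline is provided by generax; see the package for its exact import path.

key = random.PRNGKey(0)  # JAX random key
init_key, sample_key = random.split(key)  # separate keys for initialization and sampling
x = ...  # some data

# Data dependent initialization
model = NeuralSpline(x=x,
                     key=init_key,
                     n_layers=3,
                     n_res_blocks=4,
                     hidden_size=32,
                     working_size=16)

# Sample from the model
samples = model.sample(sample_key, n_samples=1000)

# Compute the log probability of data
log_prob = model.log_prob(x)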
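To make concrete what `sample` and `log_prob` compute for a normalizing flow, here is a minimal pure-JAX sketch. It is not generax's `NeuralSpline`: it replaces the spline layers with a single elementwise affine bijection, and `AffineFlow`, `init_from_data`, and `nll` are hypothetical names. Because the model is a `NamedTuple`, which JAX treats as a PyTree, `jax.grad` returns the gradients packaged as another `AffineFlow`; this is the kind of parameter bookkeeping Equinox handles for generax models.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import random


class AffineFlow(NamedTuple):
    """Hypothetical elementwise affine flow: x = mu + exp(log_scale) * z, z ~ N(0, I).

    A NamedTuple is a JAX PyTree, so jax.grad can differentiate with respect to
    the whole model and return gradients with the same structure.
    """
    mu: jnp.ndarray
    log_scale: jnp.ndarray

    def sample(self, key, n_samples):
        # Draw base samples and push them through the forward map.
        z = random.normal(key, (n_samples,) + self.mu.shape)
        return self.mu + jnp.exp(self.log_scale) * z

    def log_prob(self, x):
        # Change of variables: log p(x) = log N(f^{-1}(x); 0, I) - sum(log_scale).
        z = (x - self.mu) * jnp.exp(-self.log_scale)
        log_base = -0.5 * jnp.sum(z**2 + jnp.log(2 * jnp.pi), axis=-1)
        return log_base - jnp.sum(self.log_scale)


def init_from_data(x):
    # Data-dependent initialization: match the empirical mean and standard deviation.
    return AffineFlow(mu=jnp.mean(x, axis=0), log_scale=jnp.log(jnp.std(x, axis=0)))


def nll(model, x):
    # Maximum-likelihood objective: mean negative log-likelihood of the data.
    return -jnp.mean(model.log_prob(x))


key = random.PRNGKey(0)
data_key, sample_key = random.split(key)
x = 2.0 + 0.5 * random.normal(data_key, (1000, 2))  # toy 2D data

model = init_from_data(x)
samples = model.sample(sample_key, n_samples=1000)  # shape (1000, 2)
log_prob = model.log_prob(x)                        # shape (1000,)
grads = jax.grad(nll)(model, x)                     # an AffineFlow of gradients
```

For this toy model the data-dependent initialization is already the maximum-likelihood fit, so both gradient fields come out close to zero.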

Installation

Install the released package from PyPI:

pip install generax

Roadmap

Implemented

  • Normalizing flows
  • Continuous normalizing flows
  • Diffusion models
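A continuous normalizing flow replaces a stack of discrete bijections with an ODE dx/dt = f(x, t), and the log density changes according to d log p(x(t))/dt = -tr(df/dx). The pure-JAX sketch below integrates both with fixed-step Euler and an exact Jacobian trace. It is an illustration, not generax's implementation: `vector_field`, `cnf_forward`, and `sample_with_log_prob` are hypothetical, and a practical CNF would use an adaptive ODE solver and a stochastic trace estimator in high dimensions.

```python
import jax
import jax.numpy as jnp
from jax import random


def vector_field(params, x, t):
    # Hypothetical time-dependent vector field f(x, t) = W x + t b.
    W, b = params
    return W @ x + t * b


def cnf_forward(params, x0, n_steps=1000):
    """Euler-integrate dx/dt = f(x, t) and d(log p)/dt = -tr(df/dx) over t in [0, 1]."""
    dt = 1.0 / n_steps

    def step(carry, i):
        x, delta_logp = carry
        t = i * dt
        f = lambda y: vector_field(params, y, t)
        jac = jax.jacfwd(f)(x)  # exact Jacobian; only affordable in low dimensions
        return (x + dt * f(x), delta_logp - dt * jnp.trace(jac)), None

    init = (x0, jnp.zeros((), dtype=x0.dtype))
    (x1, delta_logp), _ = jax.lax.scan(step, init, jnp.arange(n_steps))
    return x1, delta_logp


def sample_with_log_prob(params, key, n_samples, dim):
    # Sample the N(0, I) base, push it through the ODE, and track log p along the way.
    x0 = random.normal(key, (n_samples, dim))
    log_p0 = -0.5 * jnp.sum(x0**2 + jnp.log(2 * jnp.pi), axis=-1)
    x1, delta_logp = jax.vmap(lambda x: cnf_forward(params, x))(x0)
    return x1, log_p0 + delta_logp
```

With a linear field f(x) = a x in d dimensions the exact flow is x(1) = exp(a) x(0) and the log density drops by d a, which makes a convenient sanity check for the integrator.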

These models can be trained with several objectives, including:

  • Maximum likelihood
  • Score matching
  • Flow matching
  • Variational inference
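Two of the objectives listed above have compact standard forms. The pure-JAX sketch below writes conditional flow matching with a linear interpolation path and denoising score matching at a single noise level; these are assumed textbook variants, not generax's objective code, and `linear_field` is a hypothetical toy model. Maximum likelihood is simply the mean negative `log_prob` of the data.

```python
import jax
import jax.numpy as jnp
from jax import random


def flow_matching_loss(vector_field, params, key, x1):
    """Conditional flow matching with the linear path x_t = (1 - t) x0 + t x1, x0 ~ N(0, I).

    The regression target is the path velocity dx_t/dt = x1 - x0.
    """
    t_key, noise_key = random.split(key)
    t = random.uniform(t_key, (x1.shape[0], 1))
    x0 = random.normal(noise_key, x1.shape)
    xt = (1 - t) * x0 + t * x1
    pred = vector_field(params, xt, t)
    return jnp.mean(jnp.sum((pred - (x1 - x0)) ** 2, axis=-1))


def denoising_score_matching_loss(score_fn, params, key, x, sigma=0.1):
    """Denoising score matching at a single noise level sigma.

    For x_noisy = x + sigma * eps, the regression target is
    grad log N(x_noisy; x, sigma^2 I) = -(x_noisy - x) / sigma^2 = -eps / sigma.
    """
    eps = random.normal(key, x.shape)
    pred = score_fn(params, x + sigma * eps)
    return jnp.mean(jnp.sum((pred + eps / sigma) ** 2, axis=-1))


def linear_field(params, x, t):
    # Hypothetical model: v(x, t) = x W + t b, applied row-wise to a batch.
    W, b = params
    return x @ W + t * b


data_key, loss_key = random.split(random.PRNGKey(0))
data = 3.0 + random.normal(data_key, (512, 2))
params = (jnp.zeros((2, 2)), jnp.zeros(2))
loss, grads = jax.value_and_grad(
    lambda p: flow_matching_loss(linear_field, p, loss_key, data))(params)
```

Both losses are plain regressions, so they train with the same gradient machinery as maximum likelihood and never need the model's log-density.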

Training

Generax provides a Trainer class that handles the optimization loop, periodic evaluation, and checkpointing:

# Assumes `model` (a Generax model), `max_likelihood` (objective), `tester`
# (evaluation function), `optimizer` (an Optax optimizer), and `train_ds`
# (a training data iterator) are already defined; see the tutorial.
trainer = Trainer(checkpoint_path='tmp/RealNVP')

model = trainer.train(model=model,              # Generax model
                      objective=max_likelihood, # Objective function
                      evaluate_model=tester,    # Testing function
                      optimizer=optimizer,      # Optax optimizer
                      num_steps=10000,          # Number of training steps
                      data_iterator=train_ds,   # Training data iterator
                      double_batch=1000,        # Number of batches trained in each scan loop
                      checkpoint_every=1000,    # Checkpoint interval
                      test_every=1000,          # Test interval
                      retrain=True)             # Retrain from checkpoint
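Judging by its comment, `double_batch` groups many training steps into one scan loop so that Python dispatch overhead is paid once per group rather than once per step. The sketch below shows that general pattern in pure JAX with `jax.lax.scan`, using plain SGD as a stand-in for an Optax optimizer. `make_train_scan` and `mse_loss` are hypothetical and are not meant to reflect the internals of generax's `Trainer`.

```python
import jax
import jax.numpy as jnp
from jax import random


def make_train_scan(loss_fn, learning_rate):
    """Build a jitted function that takes one SGD step per batch inside lax.scan."""

    def sgd_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return params, loss

    @jax.jit
    def train_scan(params, batches):
        # `batches` is a PyTree whose leaves have a leading axis of length n_batches;
        # the whole loop runs in one compiled call instead of n_batches Python steps.
        return jax.lax.scan(sgd_step, params, batches)

    return train_scan


def mse_loss(params, batch):
    # Hypothetical model: y = w * x + b, fit by mean squared error.
    x, y = batch
    return jnp.mean((params["w"] * x + params["b"] - y) ** 2)


n_batches, batch_size = 500, 32
x = random.uniform(random.PRNGKey(0), (n_batches, batch_size), minval=-1.0, maxval=1.0)
batches = (x, 2.0 * x + 1.0)  # noise-free targets from y = 2x + 1
params = {"w": jnp.zeros(()), "b": jnp.zeros(())}

train_scan = make_train_scan(mse_loss, learning_rate=0.1)
params, losses = train_scan(params, batches)  # losses has shape (n_batches,)
```

The trade-off is memory: every batch in the group must be stacked on the device before the scan starts.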

See the tutorial for an example.

Download files

Source Distribution

generax-0.0.1.tar.gz (22.2 kB)

Built Distribution

generax-0.0.1-py3-none-any.whl (33.2 kB, Python 3)

File details

Details for the file generax-0.0.1.tar.gz.

File metadata

  • Download URL: generax-0.0.1.tar.gz
  • Size: 22.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for generax-0.0.1.tar.gz

Algorithm     Hash digest
SHA256        b6920e05c80736be57de0e36ae98fcd32cb260b798b707d22bec559a2bb956aa
MD5           e3113a7c395f7d1b8dece3cf9b16b82f
BLAKE2b-256   725cddfdf4112c104e01834aac4af106f16da5bda7978b5c7fce74d0aa62786f
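A downloaded archive can be checked against the SHA256 digest listed for it with Python's standard `hashlib` module. The helper names `sha256_of_file` and `matches_expected` are hypothetical; the expected value is the SHA256 digest published for generax-0.0.1.tar.gz.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the SHA256 hex digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter(callable, sentinel) keeps reading until f.read returns b"" at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest published for generax-0.0.1.tar.gz.
EXPECTED_SHA256 = "b6920e05c80736be57de0e36ae98fcd32cb260b798b707d22bec559a2bb956aa"


def matches_expected(path, expected=EXPECTED_SHA256):
    # hexdigest() is lowercase, so normalize the expected value before comparing.
    return sha256_of_file(path) == expected.lower()
```

After downloading the archive into the working directory, `matches_expected("generax-0.0.1.tar.gz")` should return True. pip can enforce the same check itself in hash-checking mode by adding `--hash=sha256:<digest>` to a requirement line.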

File details

Details for the file generax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: generax-0.0.1-py3-none-any.whl
  • Size: 33.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for generax-0.0.1-py3-none-any.whl

Algorithm     Hash digest
SHA256        35f543aa849e9deb0bc92a14d184d227bb4ffe7b37471366c239e1c58e025788
MD5           06ff0fa67deaa06f215c3e7cd0128d82
BLAKE2b-256   14203cb43c998a7ddffc4cb81c7b933aef50cbd61be71010a2b0ef66feaf8c44