
Generative Models using JAX


generax

generax provides implementations of several kinds of generative models. It is built on top of Equinox, in which a model is a PyTree that holds its own parameters, so there is no separate parameter collection to keep track of. For example, the following snippet creates a neural spline flow, initializes it on data, samples from it, and evaluates log probabilities.

from jax import random

key = random.PRNGKey(0)  # JAX random key
init_key, ddi_key, sample_key = random.split(key, 3)  # use a fresh key for each random operation
x = ...  # some data, shape (batch_size, *input_shape)

# Create a flow model
model = NeuralSpline(input_shape=x.shape[1:],
                     n_flow_layers=3,
                     n_blocks=4,
                     hidden_size=32,
                     working_size=16,
                     n_spline_knots=8,
                     key=init_key)

# Data dependent initialization
model = model.data_dependent_init(x, key=ddi_key)

# Sample from the model
samples = model.sample(sample_key, n_samples=1000)

# Compute the log probability of data
log_prob = model.log_prob(x)
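To make the three calls above concrete, here is a minimal NumPy sketch of a much simpler flow than NeuralSpline: a single elementwise affine layer with a standard normal base distribution. The `AffineFlow` class and its ActNorm-style initialization are hypothetical illustrations, not generax's implementation. They only show what data-dependent initialization, sampling, and `log_prob` compute for a flow via the change-of-variables formula log p(x) = log p_z(f(x)) + log|det ∂f/∂x|.

```python
import numpy as np


class AffineFlow:
    """Hypothetical elementwise affine flow: z = (x - shift) / exp(log_scale)."""

    def __init__(self, dim):
        self.shift = np.zeros(dim)
        self.log_scale = np.zeros(dim)

    def data_dependent_init(self, x):
        # ActNorm-style init: map this batch to zero mean and unit variance per dimension.
        self.shift = x.mean(axis=0)
        self.log_scale = np.log(x.std(axis=0) + 1e-6)
        return self

    def to_latent(self, x):
        # Forward map f(x) from data space to the base (latent) space.
        return (x - self.shift) * np.exp(-self.log_scale)

    def log_prob(self, x):
        z = self.to_latent(x)
        # Standard normal log density of z, summed over dimensions.
        log_base = -0.5 * np.sum(z ** 2 + np.log(2 * np.pi), axis=-1)
        # log|det df/dx| for an elementwise affine map is -sum(log_scale).
        log_det = -np.sum(self.log_scale)
        return log_base + log_det

    def sample(self, rng, n_samples):
        # Draw from the base distribution and apply the inverse map f^{-1}(z).
        z = rng.standard_normal((n_samples, self.shift.shape[0]))
        return z * np.exp(self.log_scale) + self.shift


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(1000, 2))  # toy data

model = AffineFlow(dim=2).data_dependent_init(x)
samples = model.sample(rng, n_samples=1000)
log_prob = model.log_prob(x)
```

Unlike an Equinox module, which is immutable and returns an updated copy, this toy class mutates itself in `data_dependent_init`; it returns `self` only so the call can be reassigned the same way as in the snippet.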

Installation

generax is available on pip:

pip install generax

Roadmap

Implemented

  • Normalizing flows
  • Continuous normalizing flows
  • Diffusion models

These models can be trained with several methods, including:

  • Maximum likelihood
  • Score matching
  • Flow matching
  • Variational inference
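As an illustration of one of these objectives, here is a minimal NumPy sketch of the conditional flow matching loss using a straight-line path between a Gaussian noise sample x0 and a data point x1, whose target velocity is x1 - x0. The function name `flow_matching_loss` and the choice of path and time distribution are assumptions for illustration; generax's own flow matching objective may use a different path, time sampling, or weighting.

```python
import numpy as np


def flow_matching_loss(velocity_fn, x1, rng, x0=None):
    """Conditional flow matching loss with a straight-line path from noise x0 to data x1."""
    n = x1.shape[0]
    if x0 is None:
        x0 = rng.standard_normal(x1.shape)   # base (noise) samples
    t = rng.uniform(size=(n, 1))              # one time per example in [0, 1)
    xt = (1.0 - t) * x0 + t * x1              # point on the path at time t
    target = x1 - x0                          # time derivative of the path
    pred = velocity_fn(xt, t)                 # model's predicted velocity at (xt, t)
    return np.mean(np.sum((pred - target) ** 2, axis=-1))


rng = np.random.default_rng(0)
x1 = rng.normal(loc=2.0, size=(256, 2))           # toy data batch
zero_velocity = lambda xt, t: np.zeros_like(xt)   # stands in for an untrained model
loss = flow_matching_loss(zero_velocity, x1, rng)
```

In practice `velocity_fn` would be a neural network and the loss would be minimized with respect to its parameters; sampling then integrates the learned velocity field from t = 0 to t = 1.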

Training

generax provides a Trainer class for training these models:

import optax

optimizer = optax.adam(learning_rate=1e-3)  # any Optax optimizer; this learning rate is only illustrative
trainer = Trainer(checkpoint_path='tmp/RealNVP')

model = trainer.train(model=model,              # Generax model
                      objective=max_likelihood, # Objective function
                      evaluate_model=tester,    # Testing function
                      optimizer=optimizer,      # Optax optimizer
                      num_steps=10000,          # Number of training steps
                      data_iterator=train_ds,   # Training data iterator
                      double_batch=1000,        # Number of batches to train in each scan loop
                      checkpoint_every=1000,    # Checkpoint interval (steps)
                      test_every=1000,          # Test interval (steps)
                      retrain=True)             # Retrain from checkpoint
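The `double_batch` argument runs many training steps inside a single scan loop. In JAX this is typically done with `jax.lax.scan`, which traces the loop body once and runs it as one compiled computation instead of dispatching one Python-level step per batch. The sketch below is a pure-Python/NumPy stand-in with the same `(carry, x) -> (carry, y)` contract, so the control flow can be read and run without JAX. `scan`, `sgd_step`, `train`, and the toy quadratic objective are hypothetical and are not generax's `Trainer`.

```python
import numpy as np


def scan(step_fn, carry, xs):
    """Pure-Python stand-in for jax.lax.scan: returns (final_carry, stacked_outputs)."""
    ys = []
    for x in xs:
        carry, y = step_fn(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def sgd_step(params, batch, lr=0.1):
    # Toy objective: squared distance between params and the batch mean.
    batch_mean = batch.mean(axis=0)
    loss = np.sum((params - batch_mean) ** 2)
    grad = 2.0 * (params - batch_mean)
    return params - lr * grad, loss


def train(params, data_iterator, num_steps, double_batch, checkpoint_every):
    # Assumes num_steps is a multiple of double_batch.
    losses, checkpoints = [], []
    for step in range(0, num_steps, double_batch):
        # Pull `double_batch` batches and run them through a single scan call.
        batches = np.stack([next(data_iterator) for _ in range(double_batch)])
        params, chunk_losses = scan(sgd_step, params, batches)
        losses.extend(chunk_losses)
        if (step + double_batch) % checkpoint_every == 0:
            checkpoints.append(params.copy())  # stand-in for writing a checkpoint
    return params, losses, checkpoints


def toy_batches(rng, batch_size=32):
    while True:
        yield rng.normal(loc=5.0, size=(batch_size, 2))


rng = np.random.default_rng(0)
params, losses, checkpoints = train(np.zeros(2), toy_batches(rng),
                                    num_steps=200, double_batch=50,
                                    checkpoint_every=100)
```

Checkpointing and testing in this sketch can only happen at scan boundaries, which is one reason to choose `checkpoint_every` and `test_every` as multiples of `double_batch`.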

See the tutorial for an example.



Download files


Source Distribution

generax-0.0.3.tar.gz (40.0 kB)

Uploaded Source

Built Distribution


generax-0.0.3-py3-none-any.whl (62.9 kB)

Uploaded Python 3

File details

Details for the file generax-0.0.3.tar.gz.

File metadata

  • Download URL: generax-0.0.3.tar.gz
  • Size: 40.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for generax-0.0.3.tar.gz
Algorithm Hash digest
SHA256 79c59e748ed80f307ec13a0926a13c89716ac11f1e7c23541fffe97fef714fac
MD5 270e33f6fd1f3cc12605e815d3a66510
BLAKE2b-256 3fb0dd0eb0f406b41882638984c528bde88f2d5e7c0bed7b5ebde658e2d3a8b4
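These digests can be checked locally after downloading. A minimal sketch using Python's standard `hashlib` module; the helper names `file_sha256` and `verify` are illustrative, and the expected value is the SHA256 digest listed for generax-0.0.3.tar.gz.

```python
import hashlib

EXPECTED_SDIST_SHA256 = "79c59e748ed80f307ec13a0926a13c89716ac11f1e7c23541fffe97fef714fac"


def file_sha256(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file's SHA256 digest matches the expected hex string."""
    return file_sha256(path) == expected.lower()


# Usage, assuming the sdist was saved in the working directory:
# verify("generax-0.0.3.tar.gz")
```

pip can also enforce this at install time through its hash-checking mode, using `--hash=sha256:...` entries in a requirements file.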


File details

Details for the file generax-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: generax-0.0.3-py3-none-any.whl
  • Size: 62.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for generax-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 50cc8a2d82b2a7d6fa0b47a91dca98d8693d7943d3db6f8c0f3a7fb2e15e4798
MD5 2cd3b3133d4c7993ff05719e585212e0
BLAKE2b-256 9816c60af9d281f68ed3954a77aff995995bfe55b6db756f7121e63744959e54

