Implementation of Apple ML's TARFlow in JAX.
Transformer flows
Implementation of Apple ML's Transformer Flow (TARFlow), from "Normalizing Flows are Capable Generative Models", in JAX and Equinox.
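For readers new to TARFlow: each flow block uses a causal transformer to predict, for every token, an affine shift `mu` and log-scale `alpha` from the preceding tokens only. That makes the log-determinant a simple sum (exact likelihood) and makes the inverse run token by token. The paper stacks several such blocks and alternates the token ordering between them. Below is a minimal sketch of one block in plain JAX; the strictly-causal linear `conditioner` is a hypothetical stand-in for the transformer, and bounding `alpha` with `tanh` is an illustrative choice, not necessarily what this package does.

```python
import jax
import jax.numpy as jnp


def init_params(key, seq_len, dim):
    """Hypothetical toy conditioner: a strictly-causal linear mixer plus two heads."""
    k1, k2, k3 = jax.random.split(key, 3)
    # Strictly lower-triangular mask: token t may only see tokens < t.
    mask = jnp.tril(jnp.ones((seq_len, seq_len)), k=-1)
    mix = jax.random.normal(k1, (seq_len, seq_len)) * 0.1 * mask
    w_mu = jax.random.normal(k2, (dim, dim)) * 0.1
    w_alpha = jax.random.normal(k3, (dim, dim)) * 0.1
    return {"mix": mix, "w_mu": w_mu, "w_alpha": w_alpha}


def conditioner(params, x):
    # h[t] is a weighted sum of x[:t] only, so mu[t] and alpha[t] ignore x[t:].
    h = params["mix"] @ x
    mu = h @ params["w_mu"]
    alpha = jnp.tanh(h @ params["w_alpha"])  # bounded log-scale (assumption)
    return mu, alpha


def forward(params, x):
    """Data x (seq_len, dim) -> latent z, plus log|det dz/dx| for the likelihood."""
    mu, alpha = conditioner(params, x)
    z = (x - mu) * jnp.exp(-alpha)
    return z, -alpha.sum()


def inverse(params, z):
    """Latent z -> data x, one token at a time (why sampling is sequential)."""
    def body(t, x):
        # Rows >= t of x are still zero, but row t of the mask ignores them.
        mu, alpha = conditioner(params, x)
        return x.at[t].set(z[t] * jnp.exp(alpha[t]) + mu[t])
    return jax.lax.fori_loop(0, z.shape[0], body, jnp.zeros_like(z))


params = init_params(jax.random.PRNGKey(0), seq_len=4, dim=3)
x = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
z, logdet = forward(params, x)
x_rec = inverse(params, z)  # recovers x up to float error
```

Because the Jacobian of `forward` is block lower-triangular with diagonal `exp(-alpha)`, its log-determinant is just `-alpha.sum()`; a full model would reverse the token order (e.g. `x[::-1]`) between blocks so every token eventually conditions on every other.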
Features:
- `jax.vmap` construction and `jax.lax.scan` forward pass for the layers, for fast compilation and execution,
- multi-device training, inference and sampling,
- score-based denoising step (see paper),
- conditioning via class embedding (for discrete class labels) or adaptive layer-normalisation (for continuous variables, like in DiT),
- array-typed to the teeth for dependable execution with `jaxtyping` and `beartype`.
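The first feature (vmap to build the layers, scan to run them) is a common JAX pattern. A minimal sketch in plain JAX, assuming a hypothetical toy residual layer in place of the actual transformer block: `jax.vmap` over a batch of PRNG keys creates all layers at once, so each parameter gains a leading `(n_layers,)` axis, and `jax.lax.scan` then traces the layer body once and loops over that axis. In Equinox the same idea is typically written with `eqx.filter_vmap` over a layer constructor and `eqx.partition`/`eqx.combine` to separate array leaves from static ones inside the scan; this sketch does not claim to mirror the package's code.

```python
import jax
import jax.numpy as jnp


def init_layer(key, dim):
    """Hypothetical toy residual MLP layer (stand-in for a transformer block)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (dim, dim)) / jnp.sqrt(dim),
        "w2": jax.random.normal(k2, (dim, dim)) / jnp.sqrt(dim),
    }


def apply_layer(params, x):
    return x + jnp.tanh(x @ params["w1"]) @ params["w2"]


def init_stack(key, n_layers, dim):
    # vmap over keys: every leaf gets a leading (n_layers,) axis instead of
    # being a Python list of separately initialised arrays.
    keys = jax.random.split(key, n_layers)
    return jax.vmap(lambda k: init_layer(k, dim))(keys)


@jax.jit
def apply_stack(stacked, x):
    # scan traces apply_layer once, so compile time does not grow with depth
    # the way an unrolled Python for-loop would.
    def step(h, layer_params):
        return apply_layer(layer_params, h), None

    h, _ = jax.lax.scan(step, x, stacked)
    return h


stacked = init_stack(jax.random.PRNGKey(0), n_layers=6, dim=8)
x = jnp.ones((5, 8))
y = apply_stack(stacked, x)
```

The output matches applying the six layers one after another in a Python loop; the scanned version simply compiles a single layer body.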
To implement:
- Guidance
- Denoising
- Mixed precision
- EMA
- AdaLayerNorm
- Class embedding
- Hyperparameter/model saving
- Uniform and Gaussian noise for dequantisation
Usage
pip install -e .
Samples
I haven't optimised anything here (the authors mention varying the variance of noise used to dequantise the images), nor have I trained for very long. You can see slight artifacts due to the dequantisation noise.
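Those artifacts arise because a model trained on noisy inputs also generates noisy samples. The score-based denoising step listed under Features targets exactly this: for Gaussian noise of standard deviation `sigma`, Tweedie's formula gives the posterior mean as `x_hat = y + sigma**2 * grad_y log p(y)`, where `p` is the density of the noisy data, which is what the flow was trained on. A minimal sketch, assuming a hypothetical Gaussian `noisy_log_prob` in place of the flow's log-likelihood so the answer can be checked in closed form:

```python
import jax
import jax.numpy as jnp


def denoise(log_prob, y, sigma):
    """One Tweedie step: E[x | y] = y + sigma^2 * grad_y log p(y).

    log_prob should be the log-density of the *noisy* data. Summing
    per-example log-probs gives each example its own score, since
    examples are independent.
    """
    score = jax.grad(lambda v: jnp.sum(log_prob(v)))(y)
    return y + sigma**2 * score


def noisy_log_prob(v, s=1.0, sigma=0.5):
    # Toy stand-in: if x ~ N(0, s^2) and y = x + sigma * eps,
    # then y ~ N(0, s^2 + sigma^2) (constant term dropped).
    return -0.5 * jnp.sum(v**2) / (s**2 + sigma**2)


y = jnp.array([1.0, -2.0])
x_hat = denoise(noisy_log_prob, y, sigma=0.5)
# Exact posterior mean here is y * s^2 / (s^2 + sigma^2) = 0.8 * y.
```

For a trained flow, `log_prob` would be the model's log-likelihood and `sigma` the training noise level; note that the formula assumes Gaussian noise, not uniform dequantisation.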
Citation
@misc{zhai2024normalizingflowscapablegenerative,
title={Normalizing Flows are Capable Generative Models},
author={Shuangfei Zhai and Ruixiang Zhang and Preetum Nakkiran and David Berthelot and Jiatao Gu and Huangjie Zheng and Tianrong Chen and Miguel Angel Bautista and Navdeep Jaitly and Josh Susskind},
year={2024},
eprint={2412.06329},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.06329},
}
Download files
Source Distribution
Built Distribution
File details
Details for the file transformer_flows-0.0.7.tar.gz.
File metadata
- Download URL: transformer_flows-0.0.7.tar.gz
- Upload date:
- Size: 15.4 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.7.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4ac9172fbd75e38441260480b1fdf0f98dc192522e21153f09d75e4f7261112e |
| MD5 | 125d46963a226ded583aed3523fa6407 |
| BLAKE2b-256 | b5cf85059de35dcee1baeafdbf362f9b4217fe51979f175f5bf9bca8f4fc71cd |
File details
Details for the file transformer_flows-0.0.7-py3-none-any.whl.
File metadata
- Download URL: transformer_flows-0.0.7-py3-none-any.whl
- Upload date:
- Size: 15.4 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.7.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5cce107fc1e69bdeb9b79a24aec8fc56daf970015a1b209f20b49fdc2399f84e |
| MD5 | a4c41f569500eef7c06fad81e28e178b |
| BLAKE2b-256 | 8a078c46a8b88118f79766951076cfbc8476a0be19f64e14e636efca16303fc7 |