Project description
fax: fixed-point jax
Implicit and competitive differentiation in JAX.
Our "competitive differentiation" approach uses Competitive Gradient Descent (CGD) to solve the equality-constrained nonlinear program associated with a fixed-point problem. A standalone implementation of CGD is provided in fax/competitive/cga.py, and the equality-constrained solvers derived from it are available as fax.constrained.cga_lagrange_min and fax.constrained.cga_ecp. An implementation of implicit differentiation based on Christianson's two-phase reverse accumulation algorithm is available through the function fax.implicit.two_phase_solver.
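To give a feel for what CGD does differently from plain simultaneous gradient descent-ascent, here is a toy, dependency-free sketch of the CGD update on the scalar zero-sum game f(x, y) = x * y. This is an illustration of the update rule from the Competitive Gradient Descent paper specialized by hand to this one game, not the fax/competitive/cga.py implementation or its API.

```python
# Toy sketch of a Competitive Gradient Descent (CGD) step for the scalar
# zero-sum game f(x, y) = x * y, where x minimizes and y maximizes.
# Plain gradient descent-ascent cycles or spirals outward on this game;
# CGD, which solves a local bilinear approximation of the game at each
# step, contracts toward the unique equilibrium (0, 0).

def cgd_step(x, y, eta=0.2):
    # For f(x, y) = x * y: grad_x f = y, grad_y f = x, and the mixed
    # second derivatives D_xy f = D_yx f = 1, so the CGD preconditioner
    # (I + eta^2 * D_xy * D_yx)^(-1) reduces to the scalar 1 / (1 + eta^2).
    c = eta / (1.0 + eta ** 2)
    x_new = x - c * (y + eta * x)
    y_new = y + c * (x - eta * y)
    return x_new, y_new

x, y = 1.0, 1.0
for _ in range(500):
    x, y = cgd_step(x, y)
print(x, y)  # both very close to 0, the unique equilibrium
```

Each iterate shrinks the squared distance to the equilibrium by a factor of 1 - eta^2 / (1 + eta^2), whereas naive simultaneous updates (x - eta*y, y + eta*x) grow it by 1 + eta^2.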
See fax/constrained/constrained_test.py for examples. Please note that the API is subject to change.
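Since the exact signature of fax.implicit.two_phase_solver is not documented here, the two-phase reverse accumulation idea itself can be sketched without fax, in plain Python on a scalar problem. Phase 1 iterates the map to a fixed point; phase 2 solves the adjoint fixed-point equation to recover the gradient with respect to the parameter. All function and variable names below are illustrative, not part of the fax API.

```python
import math

def f(x, theta):
    # One Babylonian step: the fixed point x* = f(x*, theta) is sqrt(theta).
    return 0.5 * (x + theta / x)

def solve_fixed_point(theta, x0=1.0, tol=1e-12, max_iter=100):
    # Phase 1: forward iteration until the map stops moving.
    x = x0
    for _ in range(max_iter):
        x_new = f(x, theta)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x

def grad_fixed_point(theta, v=1.0, eps=1e-6, tol=1e-12, max_iter=100):
    # Phase 2: reverse accumulation. Solve the adjoint fixed point
    #   u = v + (df/dx) * u
    # by iteration, then the gradient contribution is (df/dtheta) * u.
    # Partials are estimated here by central finite differences.
    x_star = solve_fixed_point(theta)
    a = (f(x_star + eps, theta) - f(x_star - eps, theta)) / (2 * eps)  # df/dx
    b = (f(x_star, theta + eps) - f(x_star, theta - eps)) / (2 * eps)  # df/dtheta
    u = 0.0
    for _ in range(max_iter):
        u_new = v + a * u
        if abs(u_new - u) < tol:
            break
        u = u_new
    return b * u

# x* = sqrt(theta), so the gradient should equal 1 / (2 * sqrt(theta)).
print(grad_fixed_point(4.0))  # close to 0.25
```

The same structure underlies differentiating through any convergent fixed-point iteration: only the solution x*, not the trajectory leading to it, enters the gradient computation.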
References
Citing competitive differentiation:
@inproceedings{bacon2019optrl,
  author    = {Pierre-Luc Bacon and Florian Schaefer and Clement Gehring and Animashree Anandkumar and Emma Brunskill},
  title     = {A Lagrangian Method for Inverse Problems in Reinforcement Learning},
  booktitle = {NeurIPS Optimization Foundations for Reinforcement Learning Workshop},
  year      = {2019},
  url       = {http://lis.csail.mit.edu/pubs/bacon-optrl-2019.pdf},
  keywords  = {Optimization, Reinforcement Learning, Lagrangian}
}
Citing this repo:
@misc{gehring2019fax,
  author = {Clement Gehring and Pierre-Luc Bacon and Florian Schaefer},
  title  = {{FAX: differentiating fixed point problems in JAX}},
  note   = {Available at: https://github.com/gehring/fax},
  year   = {2019}
}
Hashes for jax_fixedpoint-0.0.4-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | fea60907d928103b3bf470c90a56c1905e9e388e0094770f5b51f71b7ae4ce1c
MD5 | c4e4a08a33fac4eebf47a3de199929cf
BLAKE2b-256 | 704b0661345c8d8a8abf4b8e6ac088ca1cea29a7fb855482424feb0b98be565f