Differentiable minimization in JAX using Newton's method
v0.0.0
This project essentially repackages code from the implicit layers tutorial to provide a minimize_newton function.
Given a function fn(params, z), it finds the z_star that minimizes fn for the given params. The gradient of the solution with respect to params can also be computed; this is done with a custom VJP (vector-Jacobian product) rule, as shown in the tutorial.
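The mechanism can be illustrated with a minimal sketch. It assumes z is a 1-D array, fn(params, z) returns a scalar, and fn is strictly convex in z near the solution. The name newton_minimize and the tolerance/iteration settings are hypothetical and are not mewtax's minimize_newton API. The forward pass repeats the Newton step z <- z - H^{-1} grad_z fn(params, z). The backward pass uses the implicit function theorem: at the optimum grad_z fn(params, z_star) = 0, so dz_star/dparams = -H^{-1} d(grad_z fn)/dparams, with H the Hessian of fn in z at z_star. This costs one linear solve and one VJP instead of differentiating through every iteration.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical solver settings for this sketch (not taken from mewtax).
_TOL = 1e-6
_MAX_ITER = 50


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def newton_minimize(fn, params, z_init):
    """Approximate z_star = argmin_z fn(params, z) with Newton's method.

    Assumes z is a 1-D array and fn(params, z) returns a scalar.
    """
    grad_z = jax.grad(fn, argnums=1)
    hess_z = jax.hessian(fn, argnums=1)

    def newton_step(z):
        # z <- z - H(z)^{-1} grad_z fn(params, z)
        return z - jnp.linalg.solve(hess_z(params, z), grad_z(params, z))

    def not_converged(carry):
        z_prev, z, i = carry
        return (jnp.linalg.norm(z - z_prev) > _TOL) & (i < _MAX_ITER)

    def body(carry):
        _, z, i = carry
        return z, newton_step(z), i + 1

    _, z_star, _ = jax.lax.while_loop(
        not_converged, body, (z_init, newton_step(z_init), 0)
    )
    return z_star


def _newton_fwd(fn, params, z_init):
    z_star = newton_minimize(fn, params, z_init)
    # Save what the backward pass needs.
    return z_star, (params, z_star)


def _newton_bwd(fn, res, z_star_bar):
    params, z_star = res
    grad_z = jax.grad(fn, argnums=1)
    hess = jax.hessian(fn, argnums=1)(params, z_star)
    # The Hessian is symmetric, so this solve gives v = H^{-T} z_star_bar.
    v = jnp.linalg.solve(hess, z_star_bar)
    # Pull -v back through the map params -> grad_z fn(params, z_star).
    _, vjp_params = jax.vjp(lambda p: grad_z(p, z_star), params)
    (params_bar,) = vjp_params(-v)
    # Locally, z_star does not depend on the starting point: zero cotangent.
    return params_bar, jnp.zeros_like(z_star)


newton_minimize.defvjp(_newton_fwd, _newton_bwd)


# Usage: this fn is minimized at z = log(p), so d(sum(z_star))/dp = 1/p.
def fn(p, z):
    return jnp.sum(jnp.exp(z) - p * z)


p = jnp.array([1.0, 2.0])
z_star = newton_minimize(fn, p, jnp.zeros(2))
dsum_dp = jax.grad(lambda q: jnp.sum(newton_minimize(fn, q, jnp.zeros(2))))(p)
```

Here z_star should be close to log(p) and dsum_dp close to 1/p = [1.0, 0.5]. The while_loop is never differentiated, because custom_vjp replaces its derivative with the implicit rule.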
Download files
Source Distribution
mewtax-0.0.0.tar.gz (6.3 kB)
Built Distribution
mewtax-0.0.0-py3-none-any.whl (5.3 kB)