"Automatic Functional Derivative in JAX"
Project description
Automatic Functional Differentiation in JAX
TLDR: We implemented automatic functional differentiation (as in variational calculus) in JAX, one can do g=jax.grad(F)(f)
to get the derivative of the functional F
at function f
, where g
is itself a callable python function.
Installation and Usage
Autofd can be installed via pip
pip install autofd
A minimal example on how to use this package.
import jax
import jax.numpy as jnp
from jaxtyping import Float32, Array
from autofd import function
import autofd.operators as o
# define a function
@function
def f(x: Float32[Array, ""]) -> Float32[Array, ""]:
return -x**2
# define a functional
def F(f):
return o.integrate(o.compose(jnp.exp, f))
# take the functional derivative
dFdf = jax.grad(F)(f)
# dFdf is invokable!
dFdf(1.)
Background
Function generalizes vector
In mathematics, we can see functions as generalizations of vectors. In layman's terms, we can see a vector as a list of bins with different height, e.g. v=[0.34, 0.2, 0.1, 0.43, 0.14]
. This list can be indexed using integers, v[2]
is 0.1
. If we decrease the size of each bin while we increase the number of bins to infinity, eventually we obtain an infinite dimensional vector that can be continuously indexed. In this case when we use $v(x)$ to denote we take the element at position $x$. We can call this infinite dimensional vector a function and taking the element at $x$ becomes a function call.
As we see functions as infinite dimensional vectors, the manipulations that we apply on vectors can also be generalized. For example,
- Summation becomes integration: $\sum_i v_i \rightarrow \int v(x) dx$.
- Difference becomes differentiation: $v[i]-v[i-1] \rightarrow \nabla v(x)$.
- Linear operation: $u_j=\sum_{i}w_{ji}v_i \rightarrow u(y)=\int w(y,x)v(x)dx$.
Function of functions
In JAX, we can easily write python functions that process Array
s that represent vectors or tensors. With the above generalizations, we can also write functions that process infinite dimensional arrays (functions), which we call function of functions, or higher-order functions. There are many higher-order functions in JAX, for example, jax.linearize
, jax.grad
, jax.vjp
etc. Even in pure python, higher-order functions are very common, the decorator pattern in python is implemented via higher-order functions.
Operators and functionals
Functions of functions has many names, generally we call them higher-order functions. Specifically, when the higher-order function maps a function to another function, it is often called an operator, e.g. $\nabla$ is an operator. When the higher-order function maps a function to a scalar, it is often called a functional, e.g. $f\mapsto \int f(x) dx$ is a functional. Conventionally, we use upper-case for higher-order functions, and lower-case for normal functions, like we invoke a function $f$ with input $x$ by $f(x)$, $F(f)$ means we invoke the higher-order function $F$ with input $f$ (In some contexts, square braket like $F[f]$ is used to further denote we're invoking a functional).
Functional derivatives and variational calculus
Just like we can compute the derivative of a function $f$ at a point $x$ by $\frac{df}{dx}$, we can compute the derivative of a functional too. It is denoted as $\frac{\delta F}{\delta f}$. In machine learning we use the derivative information to perform gradient descent, which helps us find the $x$ that minimizes the function $f$. Similarly, we could also use the functional derivative to perform gradient descent for the functional, which gives us the $f$ that minimizes $F$. This procedure is called the calculus of variation.
Technical Details
Please see the paper http://arxiv.org/abs/2311.18727.
@misc{lin2023automatic,
title={Automatic Functional Differentiation in JAX},
author={Min Lin},
year={2023},
eprint={2311.18727},
archivePrefix={arXiv},
primaryClass={cs.PL}
}
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
File details
Details for the file autofd-0.0.7.tar.gz
.
File metadata
- Download URL: autofd-0.0.7.tar.gz
- Upload date:
- Size: 28.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | b28ffc1e4f848368e5eb6710e8bbfc76bf2ea0bce1c3ac9b2c4d4dc1ab82986c |
|
MD5 | 36d002b14feffed30dcabfd866ed9322 |
|
BLAKE2b-256 | 02c662e109d878ee19fef98384c32d0ce81626332d5a03b6fb5938062597adb7 |