
Automatic Functional Differentiation in JAX

TL;DR: We implemented automatic functional differentiation (as in variational calculus) in JAX. One can write g = jax.grad(F)(f) to get the derivative of the functional F at the function f, where g is itself a callable Python function.

Installation and Usage

Autofd can be installed via pip:

pip install autofd

A minimal example of how to use this package:

import jax
import jax.numpy as jnp
from jaxtyping import Float32, Array
from autofd import function
import autofd.operators as o

# define a function
@function
def f(x: Float32[Array, ""]) -> Float32[Array, ""]:
  return -x**2

# define a functional
def F(f):
  return o.integrate(o.compose(jnp.exp, f))

# take the functional derivative
dFdf = jax.grad(F)(f)

# dFdf is invokable!
dFdf(1.)
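
Analytically, the functional derivative of $F(f) = \int e^{f(x)} dx$ is $\frac{\delta F}{\delta f}(x) = e^{f(x)}$, so with $f(x) = -x^2$, the call dFdf(1.) should evaluate to approximately $e^{-1} \approx 0.368$.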

Background

Function generalizes vector

In mathematics, we can see functions as generalizations of vectors. In layman's terms, a vector is a list of bins with different heights, e.g. v=[0.34, 0.2, 0.1, 0.43, 0.14]. This list can be indexed using integers: v[2] is 0.1. If we decrease the width of each bin while increasing the number of bins to infinity, we eventually obtain an infinite-dimensional vector that can be indexed continuously. In this case, we write $v(x)$ to denote taking the element at position $x$. We call this infinite-dimensional vector a function, and taking the element at $x$ becomes a function call.
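
To make this concrete, here is a minimal sketch in plain JAX (the names v and v_continuous are illustrative): a finite vector is indexed by an integer, while its continuous limit is indexed by a real number, i.e. called like a function.

import jax.numpy as jnp

# A coarse "vector": 5 bins indexed by integers.
v = jnp.array([0.34, 0.2, 0.1, 0.43, 0.14])
print(v[2])  # 0.1

# Its continuous limit: an infinite-dimensional "vector" indexed by a
# real number, i.e. a function; taking an element is a function call.
def v_continuous(x):
  return jnp.exp(-x**2)

print(v_continuous(0.3))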

As we see functions as infinite-dimensional vectors, the manipulations that we apply to vectors can also be generalized; for example (see the code sketch after this list):

  • Summation becomes integration: $\sum_i v_i \rightarrow \int v(x) dx$.
  • Difference becomes differentiation: $v_i - v_{i-1} \rightarrow \nabla v(x)$.
  • A linear map becomes an integral transform: $u_j=\sum_{i}w_{ji}v_i \rightarrow u(y)=\int w(y,x)v(x)dx$.
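
As a sketch of these correspondences in plain JAX (the grid size and the kernel w are arbitrary illustrative choices), one can discretize a function on a grid and check that the discrete operations approximate their continuous counterparts:

import jax
import jax.numpy as jnp

def v(x):
  return jnp.sin(x)

# Discretize [0, pi] into bins of width dx.
x, dx = jnp.linspace(0.0, jnp.pi, 100, retstep=True)
vx = v(x)

# Summation -> integration: sum_i v_i dx approximates \int v(x) dx = 2.
print(jnp.sum(vx) * dx)

# Difference -> differentiation: (v_i - v_{i-1}) / dx approximates v'(x),
# which autodiff computes exactly.
print((vx[1:] - vx[:-1]) / dx)   # finite differences
print(jax.vmap(jax.grad(v))(x))  # cos(x)

# Linear map -> integral transform: u(y) = \int w(y, x) v(x) dx.
w = jnp.exp(-(x[:, None] - x[None, :]) ** 2)  # an arbitrary kernel w(y, x)
u = w @ vx * dx                               # u_j = sum_i w_ji v_i dx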

Function of functions

In JAX, we can easily write Python functions that process Arrays representing vectors or tensors. With the above generalizations, we can also write functions that process infinite-dimensional arrays (functions), which we call functions of functions, or higher-order functions. There are many higher-order functions in JAX, for example jax.linearize, jax.grad, jax.vjp etc. Even in pure Python, higher-order functions are very common; the decorator pattern, for instance, is implemented via higher-order functions.
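
For instance (a plain Python/JAX illustration; the decorator traced is a made-up example):

import jax
import jax.numpy as jnp

# jax.grad is a higher-order function: it maps the function jnp.sin
# to another function, its derivative.
df = jax.grad(jnp.sin)
print(df(0.0))  # cos(0) = 1.0

# The decorator pattern is also a higher-order function: `traced`
# (a made-up example) maps a function to a wrapped function.
def traced(fn):
  def wrapper(x):
    print(f"calling {fn.__name__}({x})")
    return fn(x)
  return wrapper

@traced
def g(x):
  return x + 1.0

g(2.0)  # prints "calling g(2.0)" and returns 3.0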

Operators and functionals

Functions of functions go by many names; generally, we call them higher-order functions. Specifically, when a higher-order function maps a function to another function, it is often called an operator, e.g. $\nabla$ is an operator. When it maps a function to a scalar, it is often called a functional, e.g. $f\mapsto \int f(x) dx$ is a functional. Conventionally, we use upper case for higher-order functions and lower case for ordinary functions: just as $f(x)$ means invoking the function $f$ with input $x$, $F(f)$ means invoking the higher-order function $F$ with input $f$. (In some contexts, square brackets, as in $F[f]$, are used to further denote that a functional is being invoked.)
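
To illustrate the distinction in plain JAX (the names Nabla and Integral, and the quadrature grid, are illustrative assumptions, not autofd's API):

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x)

# An operator maps a function to a function: here f -> f'.
def Nabla(f):
  return jax.grad(f)

df = Nabla(f)   # df is itself a function
print(df(0.0))  # cos(0) = 1.0

# A functional maps a function to a scalar: here f -> \int_0^pi f(x) dx,
# approximated by a Riemann sum.
def Integral(f):
  x, dx = jnp.linspace(0.0, jnp.pi, 1000, retstep=True)
  return jnp.sum(jax.vmap(f)(x)) * dx

print(Integral(f))  # ~2.0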

Functional derivatives and variational calculus

Just as we can compute the derivative $\frac{df}{dx}$ of a function $f$ at a point $x$, we can compute the derivative of a functional, denoted $\frac{\delta F}{\delta f}$. In machine learning we use derivative information to perform gradient descent, which helps us find the $x$ that minimizes the function $f$. Similarly, we can use the functional derivative to perform gradient descent on a functional, which gives us the $f$ that minimizes $F$. This procedure is called the calculus of variations.
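
As a conceptual sketch (plain JAX, not autofd's API): if we represent f by its values on a grid, the functional becomes an ordinary function of a vector, jax.grad gives a discretized functional derivative, and gradient descent in this function space recovers the minimizer. The functional, grid, and step size below are illustrative assumptions.

import jax
import jax.numpy as jnp

# Represent f by its values on a grid of width dx.
x, dx = jnp.linspace(0.0, jnp.pi, 200, retstep=True)
target = jnp.sin(x)

def F(f_vals):
  # F(f) = \int (f(x) - sin(x))^2 dx, minimized by f = sin.
  return jnp.sum((f_vals - target) ** 2) * dx

# The discretized functional derivative is grad(F) / dx.
dF = jax.grad(F)

f_vals = jnp.zeros_like(x)  # initial guess: f = 0
for _ in range(100):
  f_vals = f_vals - 0.1 * dF(f_vals) / dx  # gradient descent in function space

print(jnp.max(jnp.abs(f_vals - target)))  # ~0: recovered f = sin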

Technical Details

Please see the paper at http://arxiv.org/abs/2311.18727, which can be cited as:

@misc{lin2023automatic,
      title={Automatic Functional Differentiation in JAX},
      author={Min Lin},
      year={2023},
      eprint={2311.18727},
      archivePrefix={arXiv},
      primaryClass={cs.PL}
}

