Fast and Easy Infinite Neural Networks in Python
Neural Tangents
News: we'll be at the NeurIPS 2019 Bayesian Deep Learning and Science meets Engineering of Deep Learning workshops, and the Symposium on Advances in Approximate Bayesian Inference. Come tell us about your experience with the library!
Overview
Neural Tangents is a high-level neural network API for specifying complex, hierarchical neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones.
Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture (see References for details and nuances of this correspondence).
Neural Tangents allows you to construct a neural network model with the usual building blocks like convolutions, pooling, residual connections, nonlinearities, etc. and obtain not only the finite model, but also the kernel function of the respective GP.
The library is written in Python using JAX and leveraging XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.
Neural Tangents is a work in progress. We happily welcome contributions!
Contents
- Installation
- 5-Minute intro
- Package description
- Technical gotchas
- Training dynamics of wide but finite networks
- Papers
- Citation
- References
Installation
To use GPU, first follow JAX's GPU installation instructions (not necessary for CPU-only usage).
Then either run
pip install neural-tangents
or, to build the bleeding-edge version from source,
git clone https://github.com/google/neural-tangents
pip install -e neural-tangents
You can now run the examples (using tensorflow_datasets) by calling:
pip install tensorflow tensorflow-datasets
python neural-tangents/examples/infinite_fcn.py
python neural-tangents/examples/weight_space.py
python neural-tangents/examples/function_space.py
Finally, you can run tests by calling:
# NOTE: a few tests will fail without
# pip install tensorflow tensorflow-datasets
for f in neural-tangents/neural_tangents/tests/*.py; do python $f; done
If you would prefer, you can get started without installing by checking out our Colab examples:
5-Minute intro
See this Colab for a detailed tutorial. Below is a very quick introduction.
Our library closely follows JAX's API for specifying neural networks, stax. In stax a network is defined by a pair of functions (init_fn, apply_fn) initializing the trainable parameters and computing the outputs of the network respectively. Below is an example of defining a 3-layer network and computing its outputs y given inputs x.
from jax import random
from jax.experimental import stax
init_fn, apply_fn = stax.serial(
stax.Dense(512), stax.Relu,
stax.Dense(512), stax.Relu,
stax.Dense(1)
)
key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)
y = apply_fn(params, x) # (10, 1) np.ndarray outputs of the neural network
Neural Tangents is designed to serve as a drop-in replacement for stax, extending the (init_fn, apply_fn) tuple to a triple (init_fn, apply_fn, kernel_fn), where kernel_fn is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs x1 and x2.
from jax import random
from neural_tangents import stax
init_fn, apply_fn, kernel_fn = stax.serial(
stax.Dense(512), stax.Relu(),
stax.Dense(512), stax.Relu(),
stax.Dense(1)
)
key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))
kernel = kernel_fn(x1, x2, 'nngp')
Note that kernel_fn can compute two covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels respectively. The NNGP kernel corresponds to the Bayesian infinite neural network [1]. The NTK corresponds to the infinite network trained with (continuous) gradient descent [5]. In the above example, we compute the NNGP kernel, but we could compute the NTK or both:
# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) np.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) np.ndarray
# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
(both.nngp == nngp).all()  # True (elementwise comparison, reduced with .all())
(both.ntk == ntk).all()  # True
# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))
# Default is to return ('nngp', 'ntk')
nngp, ntk = kernel_fn(x1, x2)
Additionally, if no third argument is specified, kernel_fn returns a Kernel namedtuple that contains additional metadata. This can be useful for composing applications of kernel_fn as follows:
kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)
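For intuition about what kernel_fn computes for the fully-connected ReLU network above, here is a minimal NumPy sketch of the standard infinite-width recursions for a Dense/ReLU stack, using the closed-form ReLU (arc-cosine) Gaussian expectations. The helper names, the w_std/b_std arguments and their default values are our assumptions for illustration; this is not the library's implementation, and we have not checked that it reproduces kernel_fn numerically. Each loop iteration is one Relu followed by one Dense, so depth=2 corresponds to the Dense, Relu, Dense, Relu, Dense network.

```python
import numpy as np

# Sketch only: hypothetical helpers, not the neural_tangents API.


def relu_expectations(k11, k12, k22):
    """Closed-form E[relu(u) relu(v)] and E[relu'(u) relu'(v)] for jointly Gaussian (u, v).

    k11, k22 are the variances and k12 the covariance of (u, v); all broadcast together.
    """
    norms = np.sqrt(k11 * k22)
    cos = np.clip(k12 / norms, -1.0, 1.0)
    theta = np.arccos(cos)
    e_phi = norms * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)
    e_dphi = (np.pi - theta) / (2 * np.pi)
    return e_phi, e_dphi


def relu_mlp_kernels(x1, x2, depth, w_std=1.0, b_std=0.0):
    """NNGP and NTK of an infinitely wide Dense -> [Relu -> Dense] * depth network."""
    d = x1.shape[1]
    # First Dense layer: covariance of pre-activations (cross and per-input variances).
    k12 = w_std**2 * x1 @ x2.T / d + b_std**2
    k11 = w_std**2 * np.sum(x1**2, axis=1) / d + b_std**2
    k22 = w_std**2 * np.sum(x2**2, axis=1) / d + b_std**2
    ntk = k12.copy()
    for _ in range(depth):
        # Expectations use the previous layer's kernel, so compute them first.
        e_phi, e_dphi = relu_expectations(k11[:, None], k12, k22[None, :])
        # NTK: previous NTK propagated through relu' plus the new layer's NNGP.
        ntk = w_std**2 * e_dphi * ntk + w_std**2 * e_phi + b_std**2
        k12 = w_std**2 * e_phi + b_std**2
        # Per-input variances: theta = 0, so E[relu(u)^2] = k / 2.
        k11 = w_std**2 * k11 / 2 + b_std**2
        k22 = w_std**2 * k22 / 2 + b_std**2
    return k12, ntk
```

For example, nngp, ntk = relu_mlp_kernels(np.asarray(x1), np.asarray(x2), depth=2) returns two (10, 20) arrays. The NTK update reuses the previous layer's NNGP through the ReLU expectations, which is why both kernels are computed in the same loop.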
Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:
import neural_tangents as nt
x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1)) # training targets
y_test_nngp = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                      get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network
y_test_ntk = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                     get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)
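The NNGP prediction above is ordinary GP regression. As a rough guide to what gp_inference computes, here is a NumPy sketch of the textbook zero-mean GP posterior mean K_td K_dd^{-1} y and covariance K_tt - K_td K_dd^{-1} K_dt, given precomputed kernel matrices. The function name and the diag_reg argument are illustrative assumptions, not the library's API. For get='ntk' the mean at t = inf has the same form with the NTK substituted for the NNGP kernel [6], but the NTK predictive covariance takes a different form that this sketch does not implement.

```python
import numpy as np

# Sketch only: hypothetical helper, not nt.predict.gp_inference.


def gp_posterior(k_train_train, k_test_train, k_test_test, y_train, diag_reg=0.0):
    """Posterior mean and covariance of a zero-mean GP conditioned on noiseless targets."""
    n = k_train_train.shape[0]
    # Cholesky factor L with L @ L.T = K_dd (+ optional jitter on the diagonal).
    chol = np.linalg.cholesky(k_train_train + diag_reg * np.eye(n))
    # alpha = K_dd^{-1} y via two solves against L instead of an explicit inverse.
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    mean = k_test_train @ alpha
    # v = L^{-1} K_dt, so v.T @ v = K_td K_dd^{-1} K_dt.
    v = np.linalg.solve(chol, k_test_train.T)
    cov = k_test_test - v.T @ v
    return mean, cov
```

Given NNGP matrices for x_train and x_test, the returned mean has the same (20, 1) shape as y_test_nngp. A small diag_reg is a common way to keep the Cholesky factorization stable when K_dd is nearly singular.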
Infinitely WideResnet
We can define a more complex (infinitely) Wide Residual Network [8] using the same nt.stax building blocks:
from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
Package description
The neural_tangents (nt) package contains the following modules and methods:
- stax - primitives to construct neural networks like Conv, Relu, serial, parallel, etc.
- predict - predictions with infinite networks:
  - predict.gp_inference - either fully Bayesian inference (get='nngp') or inference with a network trained to full convergence (infinite time) on MSE loss using continuous gradient descent (get='ntk').
  - predict.gradient_descent_mse - inference with a network trained on MSE loss with continuous gradient descent for an arbitrary finite time.
  - predict.gradient_descent - inference with a network trained on arbitrary loss with continuous gradient descent for an arbitrary finite time (using an ODE solver).
  - predict.momentum - inference with a network trained on arbitrary loss with continuous momentum gradient descent for an arbitrary finite time (using an ODE solver).
- monte_carlo_kernel_fn - compute a Monte Carlo kernel estimate of any (init_fn, apply_fn), not necessarily specified via nt.stax, enabling the kernel computation of infinite networks without closed-form expressions.
- Tools to investigate training dynamics of wide but finite neural networks, like linearize, taylor_expand, empirical_kernel_fn and more. See Training dynamics of wide but finite networks for details.
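The idea behind monte_carlo_kernel_fn can be sketched without the library: sample many random initializations of a finite network and average the outer product of its outputs. The NumPy code below does this for a hypothetical one-hidden-layer ReLU network in the NTK parameterization with W_std = sqrt(2) and no biases. The function names, width, and sample counts are assumptions for illustration, and the library's estimator may batch, normalize, or handle output channels differently.

```python
import numpy as np

# Sketch only: hypothetical network and estimator, not nt.monte_carlo_kernel_fn.


def apply_relu_net(params, x, w_std=np.sqrt(2.0)):
    """Hypothetical one-hidden-layer ReLU net in the NTK parameterization (no biases)."""
    w1, w2 = params
    h = w_std * x @ w1 / np.sqrt(x.shape[1])           # hidden pre-activations
    return w_std * np.maximum(h, 0.0) @ w2 / np.sqrt(w1.shape[1])


def monte_carlo_nngp(x1, x2, n_samples, width=512, out_dim=10, seed=0):
    """Estimate E[f(x1) f(x2)^T] by averaging over random inits and output channels."""
    rng = np.random.default_rng(seed)
    estimate = np.zeros((x1.shape[0], x2.shape[0]))
    for _ in range(n_samples):
        # NTK parameterization: all trainable weights are standard normal.
        params = (rng.standard_normal((x1.shape[1], width)),
                  rng.standard_normal((width, out_dim)))
        f1, f2 = apply_relu_net(params, x1), apply_relu_net(params, x2)
        # Output channels are i.i.d. at initialization, so average over them too.
        estimate += f1 @ f2.T / out_dim
    return estimate / n_samples
```

For this particular network the exact NNGP diagonal is 2 * ||x||^2 / d, so the diagonal of the estimate should approach that value as the number of samples grows, with Monte Carlo error shrinking roughly like one over the square root of the number of samples.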
Technical gotchas
GPU
You must follow JAX's GPU installation instructions to enable GPU support.
64-bit precision
To enable 64-bit precision, set the respective JAX flag before importing neural_tangents (see the JAX guide), for example:
from jax.config import config
config.update("jax_enable_x64", True)
import neural_tangents as nt # 64-bit precision enabled
nt.stax vs jax.experimental.stax
We remark the following differences between our library and the JAX one.
- All nt.stax layers are instantiated with a function call, i.e. nt.stax.Relu() vs jax.experimental.stax.Relu.
- All layers with trainable parameters use the NTK parameterization (see [5], Remark 1).
- nt.stax and jax.experimental.stax may have different layers and options available (for example nt.stax layers support CIRCULAR padding, but only NHWC data format).
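To make the parameterization difference concrete, here is a NumPy sketch of a single bias-free Dense layer written both ways, with gradients of the summed outputs derived by hand. The function names and the w_std value are illustrative assumptions. Both versions produce identical outputs from the same standard-normal draw, but their parameter gradients differ by a factor of w_std / sqrt(n_in), so learning rates tuned for one parameterization generally need rescaling for the other.

```python
import numpy as np

# Sketch only: hypothetical layer functions, not nt.stax.Dense or jax.experimental.stax.Dense.


def dense_ntk(omega, x, w_std):
    """NTK parameterization: trainable omega ~ N(0, 1), scale applied in the forward pass."""
    scale = w_std / np.sqrt(x.shape[-1])
    out = scale * x @ omega
    # Gradient of sum(out) w.r.t. omega: each column is the scaled input.
    grad = scale * np.outer(x, np.ones(omega.shape[1]))
    return out, grad


def dense_standard(weights, x):
    """Standard parameterization: trainable weights ~ N(0, w_std**2 / n_in)."""
    out = x @ weights
    grad = np.outer(x, np.ones(weights.shape[1]))
    return out, grad


rng = np.random.default_rng(0)
n_in, n_out, w_std = 256, 3, 1.5
x = rng.standard_normal(n_in)
omega = rng.standard_normal((n_in, n_out))
weights = w_std / np.sqrt(n_in) * omega   # the same draw, expressed in standard form

out_ntk, grad_ntk = dense_ntk(omega, x, w_std)
out_std, grad_std = dense_standard(weights, x)
```

With a fixed learning rate, one gradient step on sum(out) changes each output by the learning rate times the squared gradient norm per column: w_std**2 * ||x||**2 / n_in in the NTK parameterization, which stays O(1) as n_in grows, versus ||x||**2 in the standard one, which grows with n_in.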
Python 2
We will be dropping Python 2 support before 2020.
Training dynamics of wide but finite networks
The kernel of an infinite network, kernel_fn(x1, x2).ntk, combined with nt.predict.gradient_descent_mse allows one to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).
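Assuming the NTK stays constant during training, gradient flow on MSE loss is a linear ODE with a closed-form solution: the train residual decays as exp(-lr * Theta_dd * t). The NumPy sketch below evaluates that solution for train and test outputs; it illustrates the kind of computation gradient_descent_mse performs but is not its implementation. We assume the loss 0.5 * ||f(x_train) - y_train||^2 with no averaging over examples, so the time or learning-rate convention may differ from the library's by such a normalization factor. The function and argument names are hypothetical.

```python
import numpy as np

# Sketch only: hypothetical helper, not nt.predict.gradient_descent_mse.


def mse_gradient_flow(ntk_dd, ntk_td, y_train, f0_train, f0_test, t, learning_rate=1.0):
    """Train/test outputs after time t of gradient flow on 0.5 * ||f(x_train) - y_train||^2.

    Assumes a constant kernel: ntk_dd is Theta(x_train, x_train), ntk_td is Theta(x_test, x_train).
    """
    # exp(-lr * Theta_dd * t) via the eigendecomposition of the symmetric kernel.
    evals, evecs = np.linalg.eigh(ntk_dd)
    decay = evecs @ np.diag(np.exp(-learning_rate * evals * t)) @ evecs.T
    residual_0 = f0_train - y_train
    # Train: the residual decays as exp(-lr * Theta_dd * t) @ residual_0.
    f_train = y_train + decay @ residual_0
    # Test: f0_test - Theta_td Theta_dd^{-1} (I - exp(-lr * Theta_dd * t)) residual_0.
    correction = np.linalg.solve(ntk_dd, residual_0 - decay @ residual_0)
    f_test = f0_test - ntk_td @ correction
    return f_train, f_test
```

At t = 0 this returns the initial outputs, and as t grows the train outputs approach y_train while the test outputs approach f0_test + Theta_td Theta_dd^{-1} (y_train - f0_train).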
Weight space
Continuous gradient descent in an infinite network has been shown in [6] to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
For this, we provide two convenient methods, nt.linearize and nt.taylor_expand, which allow you to linearize or compute an arbitrary-order Taylor expansion of any function apply_fn(params, x) around some initial parameters params_0, as in apply_fn_lin = nt.linearize(apply_fn, params_0).
You can use apply_fn_lin(params, x) exactly as you would any other function (including as an input to JAX optimizers). This makes it easy to compare the training trajectory of a neural network with that of its linearization. Previous theory and experiments have examined the linearization of neural networks from inputs to logits or pre-activations, rather than from inputs to post-activations, which are substantially more nonlinear.
Example:
import jax.numpy as np
import neural_tangents as nt
def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b
W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = np.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x) # (3, 2) np.ndarray
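Because the example model above is already linear in its parameters, its linearization is exact. To see what linearization does for a nonlinear model, here is a NumPy sketch for a hypothetical single tanh layer, with the Jacobian-vector product written out by hand using tanh'(z) = 1 - tanh(z)^2. It only illustrates the first-order Taylor expansion that nt.linearize is described as computing; the names are ours, and in JAX the same Jacobian-vector product would normally come from automatic differentiation (for example jax.jvp) rather than a hand derivation.

```python
import numpy as np

# Sketch only: hypothetical model and helper, not nt.linearize.


def tanh_layer(params, x):
    """Hypothetical nonlinear model: one Dense layer followed by tanh."""
    W, b = params
    return np.tanh(x @ W + b)


def linearize_tanh_layer(params_0):
    """First-order Taylor expansion of tanh_layer around params_0, Jacobian written by hand."""
    W_0, b_0 = params_0

    def apply_fn_lin(params, x):
        W, b = params
        z_0 = x @ W_0 + b_0                  # pre-activations at params_0
        dz = x @ (W - W_0) + (b - b_0)       # change in pre-activations (linear in params)
        # f(params_0) + J(params_0) (params - params_0), using tanh'(z) = 1 - tanh(z)^2.
        return np.tanh(z_0) + (1.0 - np.tanh(z_0) ** 2) * dz

    return apply_fn_lin
```

At params_0 the two models agree exactly, and for a parameter perturbation of size eps the gap between them shrinks roughly like eps**2, as expected from a first-order expansion.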
Function space
Outputs of a linearized model evolve identically to those of an infinite one [6], but with a different kernel: the Neural Tangent Kernel [5] evaluated on the specific apply_fn of the finite network, at the specific params_0 the network is initialized with. For this we provide the nt.empirical_kernel_fn function, which accepts any apply_fn and returns a kernel_fn(x1, x2, params) that computes the empirical NTK and NNGP kernels at specific params.
Example:
import jax.random as random
import jax.numpy as np
import neural_tangents as nt
def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b
W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)
key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))
kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, x_train, params, 'ntk')
ntk_test_train = kernel_fn(x_test, x_train, params, 'ntk')
mse_predictor = nt.predict.gradient_descent_mse(
    ntk_train_train, y_train, ntk_test_train)
t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
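For the linear apply_fn in this example, the empirical NTK has a simple closed form: contracting the Jacobian of output a at x1_i with that of output b at x2_j over all entries of (W, b) gives x1_i . x2_j + 1 when a = b and 0 otherwise. The NumPy sketch below checks this with a brute-force central-difference Jacobian. The helper name is hypothetical, it keeps the full (i, a, j, b) tensor (the library's empirical_kernel_fn may contract or average output dimensions differently), and it is only practical for tiny models.

```python
import numpy as np

# Sketch only: hypothetical helper, not nt.empirical_kernel_fn.


def apply_fn(params, x):
    """The same linear model as in the example above."""
    W, b = params
    return x @ W + b


def empirical_ntk(apply_fn, params, x1, x2, eps=1e-6):
    """Theta[i, a, j, b] = sum_p df_a(x1_i)/dp * df_b(x2_j)/dp via central differences."""
    flat = np.concatenate([p.ravel() for p in params])
    shapes = [p.shape for p in params]
    sizes = [p.size for p in params]

    def unflatten(v):
        parts = np.split(v, np.cumsum(sizes)[:-1])
        return tuple(part.reshape(s) for part, s in zip(parts, shapes))

    def jacobian(x):
        cols = []
        for k in range(flat.size):
            e = np.zeros_like(flat)
            e[k] = eps
            diff = apply_fn(unflatten(flat + e), x) - apply_fn(unflatten(flat - e), x)
            cols.append(diff / (2 * eps))
        return np.stack(cols, axis=-1)  # shape (n, out_dim, n_params)

    j1, j2 = jacobian(x1), jacobian(x2)
    return np.einsum('iap,jbp->iajb', j1, j2)
```

The empirical NNGP counterpart at the same params is simply the outer product of outputs, apply_fn(params, x1) with apply_fn(params, x2), rather than of Jacobians.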
What to Expect
The success or failure of the linear approximation is highly architecture dependent. However, some rules of thumb that we've observed are:
- Convergence as the network size increases.
  - For fully-connected networks one generally observes very strong agreement by the time the layer width is 512 (RMSE of about 0.05 at the end of training).
  - For convolutional networks one generally observes reasonable agreement by the time the number of channels is 512.
- Convergence at small learning rates.
With a new model it is therefore advisable to start with a very large model on a small dataset using a small learning rate.
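The first rule of thumb can be glimpsed on a toy scale by comparing the empirical NTK of a randomly initialized one-hidden-layer ReLU network with its infinite-width limit. The NumPy sketch below does this for a hypothetical bias-free network f(x) = (w_std / sqrt(n)) v . relu((w_std / sqrt(d)) W x) in the NTK parameterization, with hand-written gradients. It measures kernel agreement at initialization only, not the agreement of training trajectories (RMSE) that the rules of thumb above refer to, and the widths and names are illustrative assumptions.

```python
import numpy as np

# Sketch only: hypothetical network and helpers, not part of neural_tangents.


def relu_expectations(k11, k12, k22):
    """Closed-form E[relu(u) relu(v)] and E[relu'(u) relu'(v)] for jointly Gaussian (u, v)."""
    norms = np.sqrt(k11 * k22)
    cos = np.clip(k12 / norms, -1.0, 1.0)
    theta = np.arccos(cos)
    e_phi = norms * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)
    e_dphi = (np.pi - theta) / (2 * np.pi)
    return e_phi, e_dphi


def infinite_ntk(x, w_std=np.sqrt(2.0)):
    """Analytic infinite-width NTK of the hypothetical network on inputs x."""
    k1 = w_std**2 * x @ x.T / x.shape[1]          # first-layer NNGP
    diag = np.diag(k1)
    e_phi, e_dphi = relu_expectations(diag[:, None], k1, diag[None, :])
    # Readout-weight term plus first-layer term propagated through relu'.
    return w_std**2 * e_phi + w_std**2 * e_dphi * k1


def empirical_ntk(x, width, rng, w_std=np.sqrt(2.0)):
    """Empirical NTK of the same network at finite width, with hand-written gradients."""
    d = x.shape[1]
    W = rng.standard_normal((width, d))
    v = rng.standard_normal(width)
    h = w_std * x @ W.T / np.sqrt(d)              # (n, width) pre-activations
    act, dact = np.maximum(h, 0.0), (h > 0).astype(float)
    # Readout weights: df/dv_j = (w_std / sqrt(width)) relu(h_j).
    theta_v = w_std**2 * act @ act.T / width
    # First layer: df/dW_jk = (w_std / sqrt(width)) v_j relu'(h_j) (w_std / sqrt(d)) x_k.
    theta_w = w_std**2 * (dact * v**2) @ dact.T / width * (w_std**2 * x @ x.T / d)
    return theta_v + theta_w
```

Comparing, say, width 16 against width 16384, the relative Frobenius gap between empirical_ntk and infinite_ntk should shrink substantially, consistent with fluctuations decaying roughly like 1/sqrt(width).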
Papers
Neural Tangents has been used in the following papers:
- Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent. Jaehoon Lee*, Lechao Xiao*, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, Jeffrey Pennington
- Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel. Soufiane Hayou, Arnaud Doucet, Judith Rousseau
Please let us know if you make use of the code in a publication and we'll add it to the list!
Citation
If you use the code in a publication, please cite the repo using the .bib entry (coming soon).
References
[1] Deep Neural Networks as Gaussian Processes. ICLR 2018. Jaehoon Lee*, Yasaman Bahri*, Roman Novak, Samuel S. Schoenholz, Jeffrey Pennington, Jascha Sohl-Dickstein
[2] Gaussian Process Behaviour in Wide Deep Neural Networks. ICLR 2018. Alexander G. de G. Matthews, Mark Rowland, Jiri Hron, Richard E. Turner, Zoubin Ghahramani
[3] Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes. ICLR 2019. Roman Novak*, Lechao Xiao*, Jaehoon Lee, Yasaman Bahri, Greg Yang, Jiri Hron, Daniel A. Abolafia, Jeffrey Pennington, Jascha Sohl-Dickstein
[4] Deep Convolutional Networks as shallow Gaussian Processes. ICLR 2019. Adrià Garriga-Alonso, Carl Edward Rasmussen, Laurence Aitchison
[5] Neural Tangent Kernel: Convergence and Generalization in Neural Networks. NeurIPS 2018. Arthur Jacot, Franck Gabriel, Clément Hongler
[6] Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent. NeurIPS 2019. Jaehoon Lee*, Lechao Xiao*, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, Jeffrey Pennington
[7] Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation. arXiv 2019. Greg Yang
[8] Wide Residual Networks. BMVC 2016. Sergey Zagoruyko, Nikos Komodakis