Fast and Easy Infinite Neural Networks in Python
Neural Tangents
ICLR 2020 Video  Paper  Quickstart  Install guide  Reference docs  Release notes
Overview
Neural Tangents is a high-level neural network API for specifying complex, hierarchical neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones.
Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture (see References for details and nuances of this correspondence).
Neural Tangents allows you to construct a neural network model with the usual building blocks like convolutions, pooling, residual connections, nonlinearities etc. and obtain not only the finite model, but also the kernel function of the respective GP.
The library is written in Python using JAX and leverages XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.
Neural Tangents is a work in progress. We happily welcome contributions!
Contents
 Colab Notebooks
 Installation
 5-Minute intro
 Package description
 Technical gotchas
 Training dynamics of wide but finite networks
 Performance
 Papers
 Citation
 References
Colab Notebooks
An easy way to get started with Neural Tangents is by playing around with the following interactive notebooks in Colaboratory. They demo the major features of Neural Tangents and show how it can be used in research.
 Neural Tangents Cookbook
 Weight Space Linearization
 Function Space Linearization
 Neural Network Phase Diagram
 Performance Benchmark: simple benchmark for Myrtle kernels used in [16]. Also see Performance.
Installation
To use GPU, first follow JAX's GPU installation instructions. Otherwise, install JAX on CPU by running
pip install jax jaxlib --upgrade
Once JAX is installed, install Neural Tangents by running
pip install neural-tangents
or, to use the bleeding-edge version from GitHub source,
git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .
You can now run the examples (using tensorflow_datasets) and tests by calling:
pip install tensorflow tensorflow-datasets more-itertools --upgrade
python examples/infinite_fcn.py
python examples/weight_space.py
python examples/function_space.py
set -e; for f in tests/*.py; do python $f; done
5-Minute intro
See this Colab for a detailed tutorial. Below is a very quick introduction.
Our library closely follows JAX's API for specifying neural networks, stax. In stax, a network is defined by a pair of functions (init_fn, apply_fn) that initialize the trainable parameters and compute the outputs of the network, respectively. Below is an example of defining a 3-layer network and computing its outputs y given inputs x.
from jax import random
from jax.experimental import stax

init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)

key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)

y = apply_fn(params, x)  # (10, 1) np.ndarray outputs of the neural network
Neural Tangents is designed to serve as a drop-in replacement for stax, extending the (init_fn, apply_fn) tuple to a triple (init_fn, apply_fn, kernel_fn), where kernel_fn is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs x1 and x2.
from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

kernel = kernel_fn(x1, x2, 'nngp')
Note that kernel_fn can compute two covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels respectively. The NNGP kernel corresponds to the Bayesian infinite neural network [1-5]. The NTK corresponds to the (continuous) gradient descent trained infinite network [10]. In the above example, we compute the NNGP kernel, but we could compute the NTK or both:
# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp')  # (10, 20) np.ndarray
ntk = kernel_fn(x1, x2, 'ntk')  # (10, 20) np.ndarray

# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp  # True
both.ntk == ntk  # True

# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))
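To make the two kernels concrete, here is a minimal sketch of what they look like in closed form for the simplest nontrivial case: a bias-free Dense, ReLU, Dense network with unit weight variance in the NTK parameterization. The function name relu_kernels and the w_std argument are illustrative assumptions, not part of the library, and the formulas are the standard arc-cosine expressions from the literature rather than the library's own code path.

```python
import math

def relu_kernels(x1, x2, w_std=1.0):
    """Closed-form NNGP and NTK entries for Dense -> ReLU -> Dense (no biases).

    Hypothetical helper for illustration; inputs are nonzero lists of floats.
    """
    d = len(x1)
    # Covariances of the first-layer pre-activations.
    k11 = w_std**2 * sum(a * a for a in x1) / d
    k22 = w_std**2 * sum(b * b for b in x2) / d
    k12 = w_std**2 * sum(a * b for a, b in zip(x1, x2)) / d
    # Angle between the two pre-activations, clipped for numerical safety.
    cos_theta = max(-1.0, min(1.0, k12 / math.sqrt(k11 * k22)))
    theta = math.acos(cos_theta)
    # E[relu(u) relu(v)] and E[relu'(u) relu'(v)] for jointly Gaussian (u, v).
    k_relu = (math.sqrt(k11 * k22)
              * (math.sin(theta) + (math.pi - theta) * cos_theta)
              / (2 * math.pi))
    k_dot = (math.pi - theta) / (2 * math.pi)
    # The last Dense layer scales the NNGP and adds the propagated NTK term.
    nngp = w_std**2 * k_relu
    ntk = nngp + w_std**2 * k12 * k_dot
    return nngp, ntk

nngp_same, ntk_same = relu_kernels([1.0, 2.0, 2.0], [1.0, 2.0, 2.0])
nngp_orth, ntk_orth = relu_kernels([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
nngp_anti, ntk_anti = relu_kernels([1.0, 2.0, 2.0], [-1.0, -2.0, -2.0])
```

For identical inputs the NTK is twice the NNGP here, because both layers contribute equally; for antiparallel inputs both kernels vanish, since no ReLU unit is active on both inputs at once.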
Additionally, if no third argument is specified, then kernel_fn will return a Kernel namedtuple that contains additional metadata. This can be useful for composing applications of kernel_fn as follows:
kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)
Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:
import neural_tangents as nt

x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1))  # training targets

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)

y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network

y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)

# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp  # True
both.ntk == y_test_ntk  # True

# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
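The classical GP inference that these predictions reduce to can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the library's implementation: gp_posterior_mean, linear_kernel and diag_reg are hypothetical names, and the linear kernel stands in for whatever kernel_fn would return (it is the NNGP kernel of a single bias-free Dense layer with unit weight variance).

```python
import numpy as np

def gp_posterior_mean(k_train_train, k_test_train, y_train, diag_reg=0.0):
    """Posterior mean K_test_train (K_train_train + reg * I)^-1 y_train.

    Hypothetical helper; kernels are precomputed matrices.
    """
    n = k_train_train.shape[0]
    regularized = k_train_train + diag_reg * np.eye(n)
    # Solve the linear system instead of forming an explicit inverse.
    return k_test_train @ np.linalg.solve(regularized, y_train)

def linear_kernel(a, b):
    # Stand-in kernel: NNGP of one bias-free Dense layer, unit weight variance.
    return a @ b.T / a.shape[1]

rng = np.random.default_rng(0)
x_train = rng.normal(size=(10, 100))
x_test = rng.normal(size=(20, 100))
y_train = rng.uniform(size=(10, 1))

k_train_train = linear_kernel(x_train, x_train)
k_test_train = linear_kernel(x_test, x_train)

y_test = gp_posterior_mean(k_train_train, k_test_train, y_train)   # (20, 1)
y_fit = gp_posterior_mean(k_train_train, k_train_train, y_train)   # (10, 1)
```

Without regularization the posterior mean interpolates the training targets, so y_fit matches y_train up to numerical error. Substituting an NTK matrix for the NNGP matrix gives the mean prediction of the gradient descent trained ensemble at convergence.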
Infinitely WideResnet
We can define a more complex, (infinitely) Wide Residual Network [14] using the same nt.stax building blocks:
from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    Main = stax.serial(
        stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
    Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2),
                       stax.parallel(Main, Shortcut),
                       stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
    blocks = []
    blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [WideResnetBlock(channels, (1, 1))]
    return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(block_size, int(16 * k)),
        WideResnetGroup(block_size, int(32 * k), (2, 2)),
        WideResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.AvgPool((8, 8)),
        stax.Flatten(),
        stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
Package description
The neural_tangents (nt) package contains the following modules and functions:
 stax: primitives to construct neural networks, like Conv, Relu, serial, parallel, etc.
 predict: predictions with infinite networks:
   predict.gradient_descent_mse: inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (t=None) time. Computed in closed form.
   predict.gradient_descent: inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver.
   predict.gradient_descent_mse_ensemble: inference with an infinite ensemble of infinite width networks, either fully Bayesian (get='nngp') or inference with MSE loss using continuous gradient descent (get='ntk'). Finite-time Bayesian inference (e.g. t=1., get='nngp') is interpreted as gradient descent on the top layer only [11], since it converges to exact Gaussian process inference with NNGP (t=None, get='nngp'). Computed in closed form.
   predict.gp_inference: exact closed form Gaussian process inference using NNGP (get='nngp'), NTK (get='ntk'), or both (get=('nngp', 'ntk')). Equivalent to predict.gradient_descent_mse_ensemble with t=None (infinite training time), but has a slightly different API (accepting the precomputed kernel matrix k_train_train instead of kernel_fn and x_train).
 monte_carlo_kernel_fn: compute a Monte Carlo kernel estimate of any (init_fn, apply_fn), not necessarily specified via nt.stax, enabling the kernel computation of infinite networks without closed-form expressions.
 Tools to investigate training dynamics of wide but finite neural networks, like linearize, taylor_expand, empirical_kernel_fn and more. See Training dynamics of wide but finite networks for details.
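The idea behind a Monte Carlo kernel estimate can be sketched without the library: draw many random initializations of a finite network and average the products of its outputs. The code below is an illustrative NumPy sketch, not the monte_carlo_kernel_fn API; init_fn, apply_fn and monte_carlo_nngp are hypothetical helpers for a bias-free one-hidden-layer ReLU network with unit-variance weights in the NTK parameterization.

```python
import numpy as np

def init_fn(rng, d, width):
    # Unit-variance Gaussian weights; scaling happens in the forward pass.
    return rng.normal(size=(d, width)), rng.normal(size=(width, 1))

def apply_fn(params, x):
    w, v = params
    d, width = w.shape
    hidden = np.maximum(x @ w / np.sqrt(d), 0.0)   # ReLU hidden layer
    return hidden @ v / np.sqrt(width)             # (batch, 1) outputs

def monte_carlo_nngp(x1, x2, n_samples=2000, width=256, seed=0):
    """Average f(x1) f(x2)^T over random initializations."""
    rng = np.random.default_rng(seed)
    total = np.zeros((x1.shape[0], x2.shape[0]))
    for _ in range(n_samples):
        params = init_fn(rng, x1.shape[1], width)
        total += apply_fn(params, x1) @ apply_fn(params, x2).T
    return total / n_samples

x = np.array([[1.0, 1.0, 1.0],
              [1.0, -1.0, 1.0]])
estimate = monte_carlo_nngp(x, x)   # (2, 2) estimated NNGP kernel
```

For this architecture the exact diagonal entries are |x|^2 / (2 d) = 0.5, so the estimate should land close to that value, with an error that shrinks as n_samples grows.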
Technical gotchas
nt.stax vs jax.experimental.stax
We remark the following differences between our library and the JAX one.
 All nt.stax layers are instantiated with a function call, i.e. nt.stax.Relu() vs jax.experimental.stax.Relu.
 All layers with trainable parameters use the NTK parameterization by default (see [10], Remark 1). However, Dense and Conv layers also support the standard parameterization via a parameterization keyword argument (see [15]).
 nt.stax and jax.experimental.stax may have different layers and options available (for example, nt.stax layers support CIRCULAR padding and have LayerNorm, but no BatchNorm).
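As a rough illustration of the difference between the two parameterizations mentioned above, the sketch below writes out a bias-free dense layer both ways using the conventional definitions from the literature. The names dense_ntk, dense_standard and w_std are assumptions made for this example and are not library functions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, w_std = 512, 4, 1.5
x = rng.normal(size=(3, n_in))
omega = rng.normal(size=(n_in, n_out))   # unit-variance Gaussian draws

def dense_ntk(weights, x):
    # NTK parameterization: trainable weights ~ N(0, 1),
    # the w_std / sqrt(n_in) factor is applied in the forward pass.
    return (w_std / np.sqrt(n_in)) * x @ weights

def dense_standard(weights, x):
    # Standard parameterization: trainable weights ~ N(0, w_std^2 / n_in),
    # no extra factor in the forward pass.
    return x @ weights

# The same random draw, expressed in each parameterization.
out_ntk = dense_ntk(omega, x)
out_std = dense_standard((w_std / np.sqrt(n_in)) * omega, x)

# Gradient of one output unit with respect to its incoming trainable weights.
grad_ntk = (w_std / np.sqrt(n_in)) * x
grad_std = x
```

The two layers compute the same function at initialization, but the gradients with respect to the trainable weights differ by a factor of w_std / sqrt(n_in), which is why the two choices lead to different training dynamics and different kernels.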
CPU and TPU performance
For CNNs with pooling, our CPU and TPU performance is suboptimal due to low core utilization (10-20%, looks like an XLA:CPU issue) and excessive padding, respectively. We will look into improving performance, but recommend NVIDIA GPUs in the meantime. See Performance.
Training dynamics of wide but finite networks
The kernel of an infinite network, kernel_fn(x1, x2).ntk, combined with nt.predict.gradient_descent_mse, allows one to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).
Weight space
Continuous gradient descent in an infinite network has been shown in [11] to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
For this, we provide two convenient functions, nt.linearize and nt.taylor_expand, which allow one to linearize or get an arbitrary-order Taylor expansion of any function apply_fn(params, x) around some initial parameters params_0, as in apply_fn_lin = nt.linearize(apply_fn, params_0).
One can use apply_fn_lin(params, x) exactly as one would any other function (including as an input to JAX optimizers). This makes it easy to compare the training trajectory of a neural network with that of its linearization. Previous theory and experiments have examined the linearization of neural networks from inputs to logits or pre-activations, rather than from inputs to post-activations, which are substantially more nonlinear.
Example:
import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
    W, b = params
    return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))

apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = np.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2

x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x)  # (3, 2) np.ndarray
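The model in the example above is already linear in its parameters, so its linearization coincides with the model itself. To see what a first-order expansion in the parameters does to a nonlinear model, here is a hand-written NumPy sketch for a hypothetical tanh model. The helper linearize_by_hand is an assumption for illustration: it hard-codes the derivative of this one model, whereas nt.linearize works for arbitrary functions.

```python
import numpy as np

def apply_fn(params, x):
    W, b = params
    return np.tanh(x @ W + b)

def linearize_by_hand(params_0):
    """First-order Taylor expansion of apply_fn in the parameters at params_0."""
    W_0, b_0 = params_0

    def apply_fn_lin(params, x):
        W, b = params
        f_0 = np.tanh(x @ W_0 + b_0)
        # Directional derivative of tanh(x W + b) along (W - W_0, b - b_0).
        jvp = (1.0 - f_0**2) * (x @ (W - W_0) + (b - b_0))
        return f_0 + jvp

    return apply_fn_lin

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros(2)
apply_fn_lin = linearize_by_hand((W_0, b_0))
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])

# Exact at the expansion point, accurate nearby, inaccurate far away.
at_init = apply_fn_lin((W_0, b_0), x)
near = (W_0 + 1e-3, b_0 + 1e-3)
far = (W_0 + 0.5, b_0 + 0.5)
gap_near = np.max(np.abs(apply_fn_lin(near, x) - apply_fn(near, x)))
gap_far = np.max(np.abs(apply_fn_lin(far, x) - apply_fn(far, x)))
```

The gap grows quadratically with the distance from params_0, which is the quantity one would monitor when comparing a training trajectory with that of the linearized model.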
Function space
Outputs of a linearized model evolve identically to those of an infinite one [11] but with a different kernel: specifically, the Neural Tangent Kernel [10] evaluated on the specific apply_fn of the finite network given the specific params_0 that the network is initialized with. For this we provide the nt.empirical_kernel_fn function, which accepts any apply_fn and returns a kernel_fn(x1, x2, get, params) that computes the empirical NTK and/or NNGP kernels (based on get) on specific params.
Example:
import jax.random as random
import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
    W, b = params
    return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)

key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))

kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)

t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
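The closed-form evolution that such a predictor evaluates can be sketched directly in NumPy. The code below is a sketch under stated assumptions, not the library's implementation: gradient_flow_mse and learning_rate are hypothetical names, the kernel is assumed positive definite, and the library may normalize the loss or learning rate differently. For the linear model above, the NTK of each output is x . x' + 1 (derived by hand), which is used as the demo kernel.

```python
import numpy as np

def gradient_flow_mse(k_train_train, k_test_train, y_train,
                      f0_train, f0_test, t, learning_rate=1.0):
    """Outputs after time t of continuous gradient descent on MSE loss.

    Train: f_t = y + exp(-lr K t) (f_0 - y)
    Test:  f_t = f_0 - K_test_train K^-1 (I - exp(-lr K t)) (f_0 - y)
    """
    evals, evecs = np.linalg.eigh(k_train_train)
    # Initial residual expressed in the kernel eigenbasis.
    coeffs = evecs.T @ (f0_train - y_train)
    decay = np.exp(-learning_rate * evals * t)[:, None]
    f_train_t = y_train + evecs @ (decay * coeffs)
    f_test_t = f0_test - k_test_train @ (
        evecs @ ((1.0 - decay) / evals[:, None] * coeffs))
    return f_train_t, f_test_t

def ntk(a, b):
    # Hand-derived per-output NTK of the model x W + b.
    return a @ b.T + 1.0

x_train = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
x_test = np.array([[0.1, 0.9], [0.7, 0.7]])
y_train = np.array([[1., 0.], [0., 1.], [1., 1.]])

# With W_0 the identity and b_0 = 0, the initial outputs equal the inputs.
f0_train, f0_test = x_train.copy(), x_test.copy()
k_tt, k_xt = ntk(x_train, x_train), ntk(x_test, x_train)

start = gradient_flow_mse(k_tt, k_xt, y_train, f0_train, f0_test, t=0.0)
mid = gradient_flow_mse(k_tt, k_xt, y_train, f0_train, f0_test, t=5.0)
end = gradient_flow_mse(k_tt, k_xt, y_train, f0_train, f0_test, t=1e4)
```

At t = 0 the function returns the initial outputs, and for large t the training outputs converge to the targets, while the test outputs converge to the kernel regression prediction applied to the initial residual.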
What to Expect
The success or failure of the linear approximation is highly architecture dependent. However, some rules of thumb that we've observed are:

 Convergence as the network size increases.
   For fully-connected networks one generally observes very strong agreement by the time the layer width is 512 (RMSE of about 0.05 at the end of training).
   For convolutional networks one generally observes reasonable agreement by the time the number of channels is 512.
 Convergence at small learning rates.
With a new model it is therefore advisable to start with a very large model on a small dataset using a small learning rate.
Performance
In the table below we measure the time to compute a single NTK entry in a 21-layer CNN (3x3 filters, no strides, SAME padding, ReLU) on inputs of shape 3x32x32. Precisely:
layers = []
for _ in range(21):
    layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]
CNN with pooling
Top layer is stax.GlobalAvgPool():
_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))
| Platform | Precision | Milliseconds / NTK entry | Max batch size (NxN) |
|---|---|---|---|
| CPU, >56 cores, >700 Gb RAM | 32 | 112.90 | >= 128 |
| CPU, >56 cores, >700 Gb RAM | 64 | 258.55 | 95 (fastest: 72) |
| TPU v2 | 32/16 | 3.2550 | 16 |
| TPU v3 | 32/16 | 2.3022 | 24 |
| NVIDIA P100 | 32 | 5.9433 | 26 |
| NVIDIA P100 | 64 | 11.349 | 18 |
| NVIDIA V100 | 32 | 2.7001 | 26 |
| NVIDIA V100 | 64 | 6.2058 | 18 |
CNN without pooling
Top layer is stax.Flatten():
_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))
| Platform | Precision | Milliseconds / NTK entry | Max batch size (NxN) |
|---|---|---|---|
| CPU, >56 cores, >700 Gb RAM | 32 | 0.12013 | 2048 <= N < 4096 (fastest: 512) |
| CPU, >56 cores, >700 Gb RAM | 64 | 0.3414 | 2048 <= N < 4096 (fastest: 256) |
| TPU v2 | 32/16 | 0.0015722 | 512 <= N < 1024 |
| TPU v3 | 32/16 | 0.0010647 | 512 <= N < 1024 |
| NVIDIA P100 | 32 | 0.015171 | 512 <= N < 1024 |
| NVIDIA P100 | 64 | 0.019894 | 512 <= N < 1024 |
| NVIDIA V100 | 32 | 0.0046510 | 512 <= N < 1024 |
| NVIDIA V100 | 64 | 0.010822 | 512 <= N < 1024 |
Tested using version 0.2.1. All GPU results are per single accelerator. Note that runtime is proportional to the depth of your network. If your performance differs significantly, please file a bug!
Myrtle network
The Colab notebook Performance Benchmark demonstrates how one would construct and benchmark kernels. To demonstrate flexibility, we took the architecture from [16] as an example. With an NVIDIA V100 at 64-bit precision, nt took 316/330/508 GPU-hours on the full 60k CIFAR-10 dataset for Myrtle-5/7/10 kernels.
Papers
Neural Tangents has been used in the following papers:
 Correlated Weights in Infinite Limits of Deep Convolutional Neural Networks
 Dataset Meta-Learning from Kernel Ridge-Regression
 Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the Neural Tangent Kernel
 Stable ResNet
 Label-Aware Neural Tangent Kernel: Toward Better Generalization and Local Elasticity
 Semi-supervised Batch Active Learning via Bilevel Optimization
 Temperature check: theory and practice for training models with softmax-cross-entropy losses
 Experimental Design for Overparameterized Learning with Application to Single Shot Deep Active Learning
 How Neural Networks Extrapolate: From Feedforward to Graph Neural Networks
 Exploring the Uncertainty Properties of Neural Networks’ Implicit Priors in the Infinite-Width Limit
 Cold Posteriors and Aleatoric Uncertainty
 Asymptotics of Wide Convolutional Neural Networks
 Finite Versus Infinite Neural Networks: an Empirical Study
 Bayesian Deep Ensembles via the Neural Tangent Kernel
 The Surprising Simplicity of the Early-Time Learning Dynamics of Neural Networks
 When Do Neural Networks Outperform Kernel Methods?
 Statistical Mechanics of Generalization in Kernel Regression
 Exact posterior distributions of wide Bayesian neural networks
 Infinite attention: NNGP and NTK for deep attention networks
 Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
 Finding trainable sparse networks through Neural Tangent Transfer
 Coresets via Bilevel Optimization for Continual Learning and Streaming
 On the Neural Tangent Kernel of Deep Networks with Orthogonal Initialization
 The large learning rate phase of deep learning: the catapult mechanism
 Spectrum Dependent Learning Curves in Kernel Regression and Wide Neural Networks
 Taylorized Training: Towards Better Approximation of Neural Network Training at Finite Width
 On the Infinite Width Limit of Neural Networks with a Standard Parameterization
 Disentangling Trainability and Generalization in Deep Learning
 Information in Infinite Ensembles of Infinitely-Wide Neural Networks
 Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel
 Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
 Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes
Please let us know if you make use of the code in a publication and we'll add it to the list!
Citation
If you use the code in a publication, please cite our ICLR 2020 paper:
@inproceedings{neuraltangents2020,
title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://github.com/google/neural-tangents}
}
References
[1] Priors for Infinite Networks
[2] Exponential expressivity in deep neural networks through transient chaos
[3] Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity
[4] Deep Information Propagation
[5] Deep Neural Networks as Gaussian Processes
[6] Gaussian Process Behaviour in Wide Deep Neural Networks
[7] Dynamical Isometry and a Mean Field Theory of CNNs: How to Train 10,000-Layer Vanilla Convolutional Neural Networks
[8] Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes
[9] Deep Convolutional Networks as shallow Gaussian Processes
[10] Neural Tangent Kernel: Convergence and Generalization in Neural Networks
[11] Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
[12] Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation
[13] Mean Field Residual Networks: On the Edge of Chaos
[14] Wide Residual Networks
[15] On the Infinite Width Limit of Neural Networks with a Standard Parameterization
[16] Neural Kernels Without Tangents