
Fast and Easy Infinite Neural Networks in Python


WARNING: Our next major release (v0.5.0; already in GitHub) will include significant refactoring and could break your code if you use internal functions like nt.utils.typing, nt.utils.utils, nt.utils.Kernel etc. (the public API will remain unchanged). This should be easily fixed by updating the imports, e.g. nt.utils -> nt._src.utils.

Neural Tangents

ICLR 2020 Video | Paper | Quickstart | Install guide | Reference docs | Release notes



Neural Tangents is a high-level neural network API for specifying complex, hierarchical neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones.

Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture. See References for details and nuances of this correspondence. Also see this listing of papers written by the creators of Neural Tangents which study the infinite width limit of neural networks.

Neural Tangents allows you to construct a neural network model from common building blocks like convolutions, pooling, residual connections, nonlinearities, and more, and obtain not only the finite model, but also the kernel function of the respective GP.

The library is written in Python using JAX and leverages XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.

Neural Tangents is a work in progress. We happily welcome contributions!


Colab Notebooks

An easy way to get started with Neural Tangents is by playing around with the following interactive notebooks in Colaboratory. They demo the major features of Neural Tangents and show how it can be used in research.


Installation

To use GPU, first follow JAX's GPU installation instructions. Otherwise, install JAX on CPU by running

pip install jax jaxlib --upgrade

Once JAX is installed, install Neural Tangents by running

pip install neural-tangents

or, to use the bleeding-edge version from GitHub source,

git clone; cd neural-tangents
pip install -e .

You can now run the examples and tests by calling:

pip install .[testing]

python examples/
python examples/
python examples/
python examples/

set -e; for f in tests/*.py; do python $f; done

5-Minute intro

See this Colab for a detailed tutorial. Below is a very quick introduction.

Our library closely follows JAX's API for specifying neural networks, stax. In stax a network is defined by a pair of functions (init_fn, apply_fn) initializing the trainable parameters and computing the outputs of the network respectively. Below is an example of defining a 3-layer network and computing its outputs y given inputs x.

from jax import random
from jax.example_libraries import stax

init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)

key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)

y = apply_fn(params, x)  # (10, 1) np.ndarray outputs of the neural network
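To make the (init_fn, apply_fn) pattern concrete, here is a toy re-implementation in plain NumPy. The names dense, relu and serial are hypothetical stand-ins written for this sketch, not the jax.example_libraries.stax API (which, for instance, uses stax.Relu without a call and its own default initializers). The point is only the structure: every layer is a pair of functions, init_fn threads the input shape through the network while creating parameters, and apply_fn threads the activations.

```python
import numpy as np


def dense(out_dim):
    def init_fn(rng, input_shape):
        in_dim = input_shape[-1]
        W = rng.normal(size=(in_dim, out_dim)) / np.sqrt(in_dim)
        b = np.zeros(out_dim)
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fn(params, x):
        W, b = params
        return x @ W + b

    return init_fn, apply_fn


def relu():
    def init_fn(rng, input_shape):
        return input_shape, ()  # no trainable parameters

    def apply_fn(params, x):
        return np.maximum(x, 0.)

    return init_fn, apply_fn


def serial(*layers):
    init_fns, apply_fns = zip(*layers)

    def init_fn(rng, input_shape):
        params = []
        for init in init_fns:
            input_shape, p = init(rng, input_shape)  # shape flows layer to layer
            params.append(p)
        return input_shape, params

    def apply_fn(params, x):
        for apply, p in zip(apply_fns, params):
            x = apply(p, x)
        return x

    return init_fn, apply_fn


init_fn, apply_fn = serial(dense(512), relu(), dense(512), relu(), dense(1))
rng = np.random.default_rng(1)
x = rng.normal(size=(10, 100))
out_shape, params = init_fn(rng, x.shape)
y = apply_fn(params, x)  # shape (10, 1)
```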

Neural Tangents is designed to serve as a drop-in replacement for stax, extending the (init_fn, apply_fn) tuple to a triple (init_fn, apply_fn, kernel_fn), where kernel_fn is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs x1 and x2.

from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

kernel = kernel_fn(x1, x2, 'nngp')

Note that kernel_fn can compute two covariance matrices, corresponding to the Neural Network Gaussian Process (NNGP) kernel and the Neural Tangent Kernel (NTK), respectively. The NNGP kernel corresponds to the Bayesian infinite neural network [1-5]. The NTK corresponds to the infinite network trained with (continuous) gradient descent [10]. In the above example we compute the NNGP kernel, but we could compute the NTK or both:

# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) np.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) np.ndarray

# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
(both.nngp == nngp).all()  # True
(both.ntk == ntk).all()  # True

# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))
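For intuition about what kernel_fn returns for the Dense/Relu network above, the following NumPy sketch implements the standard layer-wise recursion for the NNGP and NTK of a fully-connected ReLU network in the NTK parameterization, using the closed-form arc-cosine expectations of ReLU. It assumes W_std = 1 and b_std = 0 and is our own illustration of the recursion, not Neural Tangents' implementation; we have not checked it digit-for-digit against nt's output. The helper names relu_moments and relu_mlp_kernels are hypothetical.

```python
import numpy as np


def relu_moments(k12, k11, k22):
    """E[relu(u) relu(v)] and E[relu'(u) relu'(v)] for a centered Gaussian (u, v)
    with Var(u) = k11, Var(v) = k22, Cov(u, v) = k12 (arc-cosine formulas)."""
    norms = np.sqrt(k11 * k22)
    cos_theta = np.clip(k12 / norms, -1., 1.)
    theta = np.arccos(cos_theta)
    e_relu = norms / (2 * np.pi) * (np.sin(theta) + (np.pi - theta) * cos_theta)
    e_relu_grad = (np.pi - theta) / (2 * np.pi)
    return e_relu, e_relu_grad


def relu_mlp_kernels(x1, x2, num_relu_dense=2, w_std=1., b_std=0.):
    """NNGP and NTK of Dense -> (Relu -> Dense) * num_relu_dense, NTK parameterization."""
    d = x1.shape[1]
    # First Dense layer: NNGP and NTK coincide.
    nngp = w_std**2 * (x1 @ x2.T) / d + b_std**2
    k11 = w_std**2 * np.sum(x1**2, axis=1) / d + b_std**2  # variances at x1
    k22 = w_std**2 * np.sum(x2**2, axis=1) / d + b_std**2  # variances at x2
    ntk = nngp.copy()
    for _ in range(num_relu_dense):
        e_relu, e_relu_grad = relu_moments(nngp, k11[:, None], k22[None, :])
        nngp = w_std**2 * e_relu + b_std**2
        # NTK recursion: new NNGP term plus previous NTK scaled by E[relu' relu'].
        ntk = nngp + w_std**2 * e_relu_grad * ntk
        # E[relu(u)^2] = Var(u) / 2 for a centered Gaussian u.
        k11 = w_std**2 * k11 / 2 + b_std**2
        k22 = w_std**2 * k22 / 2 + b_std**2
    return nngp, ntk


rng = np.random.default_rng(1)
x1 = rng.normal(size=(10, 100))
x2 = rng.normal(size=(20, 100))
nngp, ntk = relu_mlp_kernels(x1, x2)  # both of shape (10, 20)
```

The default num_relu_dense=2 matches the Dense(512), Relu(), Dense(512), Relu(), Dense(1) stack: the widths drop out entirely, which is why the infinite-width kernel depends only on the architecture.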

Additionally, if no third argument is specified, kernel_fn returns a Kernel namedtuple that contains additional metadata. This can be useful for composing applications of kernel_fn as follows:

kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)

Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:

import neural_tangents as nt

x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1))  # training targets

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)

y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network

y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)

# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
(both.nngp == y_test_nngp).all()  # True
(both.ntk == y_test_ntk).all()  # True

# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
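At infinite training time, the mean prediction in both cases has the classical GP form K(x_test, x_train) K(x_train, x_train)^{-1} y_train, with K the NNGP kernel for get='nngp' or the NTK for get='ntk' (the predictive covariances differ). The sketch below computes only that mean with a Cholesky factorization. It uses an RBF kernel purely as a stand-in so the example runs on its own; the names gp_posterior_mean and rbf are hypothetical and this is not nt.predict's implementation.

```python
import numpy as np


def gp_posterior_mean(k_test_train, k_train_train, y_train, diag_reg=0.):
    """Posterior mean K_* K^{-1} y of zero-mean GP regression."""
    n = k_train_train.shape[0]
    # Factor K = L L^T once, then solve two triangular systems instead of inverting K.
    chol = np.linalg.cholesky(k_train_train + diag_reg * np.eye(n))
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    return k_test_train @ alpha


def rbf(a, b, lengthscale=1.):
    """Stand-in kernel; swap in an NNGP or NTK matrix in practice."""
    sq = np.sum(a**2, 1)[:, None] + np.sum(b**2, 1)[None, :] - 2 * a @ b.T
    return np.exp(-0.5 * sq / lengthscale**2)


rng = np.random.default_rng(0)
x_train, x_test = rng.normal(size=(10, 3)), rng.normal(size=(20, 3))
y_train = rng.uniform(size=(10, 1))
y_test = gp_posterior_mean(rbf(x_test, x_train), rbf(x_train, x_train), y_train,
                           diag_reg=1e-10)  # shape (20, 1)
```

Evaluated on the training inputs themselves, this mean interpolates y_train, which is the behavior expected of a network trained to zero MSE loss.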

Infinitely WideResnet

We can define a more complex, (infinitely) Wide Residual Network [14] using the same nt.stax building blocks:

from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
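The AvgPool((8, 8)) window follows from simple shape arithmetic. Assuming 32x32 (CIFAR-10-sized) inputs, SAME-padded convolutions keep the spatial size at stride 1 and halve it at stride 2, so the three groups see 32, 16 and 8 pixels per side; an 8x8 window under VALID padding then reduces the last feature map to 1x1 regardless of the pooling stride. The helper names below are hypothetical.

```python
import math


def same_conv_output_size(size, stride):
    """Spatial output size of a convolution with padding='SAME'."""
    return math.ceil(size / stride)


def valid_pool_output_size(size, window, stride=1):
    """Spatial output size of pooling with padding='VALID'."""
    return (size - window) // stride + 1


size = 32  # assumption: 32x32 (CIFAR-10-sized) inputs
size = same_conv_output_size(size, 1)  # stem: stax.Conv(16, (3, 3), padding='SAME')
group_sizes = []
for stride in (1, 2, 2):  # strides of the three WideResnetGroups
    # Only the first block of a group is strided; later blocks use stride 1.
    size = same_conv_output_size(size, stride)
    group_sizes.append(size)
pooled = valid_pool_output_size(size, window=8)
# group_sizes: [32, 16, 8]; pooled: 1
```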

Package description

The neural_tangents (nt) package contains the following modules and functions:

  • stax - primitives to construct neural networks like Conv, Relu, serial, parallel etc.

  • predict - predictions with infinite networks:

    • predict.gradient_descent_mse - inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (t=None) time. Computed in closed form.

    • predict.gradient_descent - inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver.

    • predict.gradient_descent_mse_ensemble - inference with an infinite ensemble of infinite width networks, either fully Bayesian (get='nngp') or inference with MSE loss using continuous gradient descent (get='ntk'). Finite-time Bayesian inference (e.g. t=1., get='nngp') is interpreted as gradient descent on the top layer only [11], since it converges to exact Gaussian process inference with NNGP (t=None, get='nngp'). Computed in closed form.

    • predict.gp_inference - exact closed form Gaussian process inference using NNGP (get='nngp'), NTK (get='ntk'), or both (get=('nngp', 'ntk')). Equivalent to predict.gradient_descent_mse_ensemble with t=None (infinite training time), but has a slightly different API (accepting precomputed kernel matrix k_train_train instead of kernel_fn and x_train).

  • monte_carlo_kernel_fn - compute a Monte Carlo kernel estimate of any (init_fn, apply_fn), not necessarily specified via nt.stax, enabling the kernel computation of infinite networks without closed-form expressions.

  • Tools to investigate training dynamics of wide but finite neural networks, like linearize, taylor_expand, empirical.kernel_fn and more. See Training dynamics of wide but finite networks for details.
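The idea behind a Monte Carlo kernel estimate can be seen on a single ReLU layer: averaging relu(w·x1) relu(w·x2) over many random weight vectors w approaches the closed-form (arc-cosine) NNGP that nt.stax computes analytically, with error shrinking roughly as 1/sqrt(width). This NumPy sketch is our own illustration of the principle, not how monte_carlo_kernel_fn is implemented; the function names are hypothetical and w ~ N(0, I/d) is an assumed initialization.

```python
import numpy as np


def relu_nngp_closed_form(x1, x2):
    """Arc-cosine kernel E[relu(w.x1) relu(w.x2)] for w ~ N(0, I / d)."""
    d = x1.shape[1]
    k12 = x1 @ x2.T / d
    k11 = np.sum(x1**2, axis=1)[:, None] / d
    k22 = np.sum(x2**2, axis=1)[None, :] / d
    norms = np.sqrt(k11 * k22)
    theta = np.arccos(np.clip(k12 / norms, -1., 1.))
    return norms / (2 * np.pi) * (np.sin(theta) + (np.pi - theta) * np.cos(theta))


def relu_nngp_monte_carlo(x1, x2, width, rng):
    """Average of relu(w.x1) relu(w.x2) over `width` random hidden units."""
    d = x1.shape[1]
    W = rng.normal(size=(d, width)) / np.sqrt(d)
    h1 = np.maximum(x1 @ W, 0.)
    h2 = np.maximum(x2 @ W, 0.)
    return h1 @ h2.T / width


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))
exact = relu_nngp_closed_form(x, x)
for width in (10, 1_000, 100_000):
    estimate = relu_nngp_monte_carlo(x, x, width, rng)
    print(f"width={width:>7}: max abs error {np.abs(estimate - exact).max():.4f}")
```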

Technical gotchas

nt.stax vs jax.example_libraries.stax

We note the following differences between our library and the JAX one:

  • All nt.stax layers are instantiated with a function call, i.e. nt.stax.Relu() vs jax.example_libraries.stax.Relu.
  • All layers with trainable parameters use the NTK parameterization by default (see [10], Remark 1). However, Dense and Conv layers also support the standard parameterization via a parameterization keyword argument (see [15]).
  • nt.stax and jax.example_libraries.stax may have different layers and options available (for example, nt.stax layers support CIRCULAR padding and have LayerNorm, but no BatchNorm).
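The difference between the two parameterizations can be checked numerically on one dense layer. In the NTK parameterization the trainable weights are drawn from N(0, 1) and the factor W_std / sqrt(n_in) is applied in the forward pass; in the standard parameterization that factor is folded into the initialization variance instead. With shared random draws the two layers compute the same function at initialization, but the gradient with respect to the trainable weights differs by W_std / sqrt(n_in), which changes gradient descent dynamics (see [10] and [15]). This sketch omits biases and uses hypothetical function names and finite-difference gradients for transparency.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, w_std = 64, 3, 1.5
omega = rng.normal(size=(n_in, n_out))  # shared N(0, 1) draws
x = rng.normal(size=(1, n_in))


def forward_ntk(omega, x):
    # NTK parameterization: trainable omega ~ N(0, 1), scale applied in the forward pass.
    return (w_std / np.sqrt(n_in)) * (x @ omega)


def forward_standard(W, x):
    # Standard parameterization: trainable W initialized with variance w_std**2 / n_in.
    return x @ W


W = (w_std / np.sqrt(n_in)) * omega  # same function as forward_ntk at initialization


def grad_wrt_column(forward, param, x, j, eps=1e-6):
    """Central finite-difference gradient of output j w.r.t. column j of the weights."""
    grad = np.zeros(param.shape[0])
    for i in range(param.shape[0]):
        step = np.zeros_like(param)
        step[i, j] = eps
        grad[i] = (forward(param + step, x)[0, j] - forward(param - step, x)[0, j]) / (2 * eps)
    return grad


grad_ntk = grad_wrt_column(forward_ntk, omega, x, j=0)
grad_standard = grad_wrt_column(forward_standard, W, x, j=0)
# Same outputs at init; grad_ntk == grad_standard * w_std / sqrt(n_in).
```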

CPU and TPU performance

For CNNs with pooling, our CPU and TPU performance is suboptimal due to low core utilization (10-20%, which looks like an XLA:CPU issue) and excessive padding, respectively. We will look into improving performance, but recommend NVIDIA GPUs in the meantime. See Performance.

Training dynamics of wide but finite networks

The kernel of an infinite network, kernel_fn(x1, x2).ntk, combined with nt.predict.gradient_descent_mse allows one to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).
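The closed form behind this kind of tracking can be sketched directly. Assuming the loss 0.5 * ||f_train - y_train||^2 summed over training points and gradient flow with learning rate lr, the training outputs obey df/dt = -lr * Theta (f - y), so f_train(t) = y + exp(-lr Theta t)(f_train(0) - y), and test outputs move through Theta_test_train Theta^{-1}. nt.predict.gradient_descent_mse may normalize the loss and learning rate differently, so treat this as the underlying math rather than a drop-in replacement. The function gradient_flow_mse and the stand-in kernel are hypothetical.

```python
import numpy as np


def gradient_flow_mse(ntk_train_train, ntk_test_train, y_train, f_train_0, f_test_0, t, lr=1.):
    """Closed-form outputs after time t of gradient flow on 0.5 * ||f_train - y_train||^2."""
    evals, evecs = np.linalg.eigh(ntk_train_train)
    # exp(-lr * t * Theta) for the symmetric train-train NTK, via its eigendecomposition.
    decay = evecs @ np.diag(np.exp(-lr * t * evals)) @ evecs.T
    residual_0 = f_train_0 - y_train
    f_train_t = y_train + decay @ residual_0
    # Test outputs: f_test_0 - Theta_* Theta^{-1} (I - exp(-lr Theta t)) residual_0.
    f_test_t = f_test_0 - ntk_test_train @ np.linalg.solve(
        ntk_train_train, residual_0 - decay @ residual_0)
    return f_train_t, f_test_t


def stand_in_kernel(a, b):
    """Positive-definite stand-in for an NTK so the example runs on its own."""
    return a @ b.T / a.shape[1] + 1.


rng = np.random.default_rng(0)
x_train, x_test = rng.normal(size=(5, 8)), rng.normal(size=(3, 8))
y_train = rng.uniform(size=(5, 1))
f_train_0, f_test_0 = rng.normal(size=(5, 1)), rng.normal(size=(3, 1))
f_train_t, f_test_t = gradient_flow_mse(stand_in_kernel(x_train, x_train),
                                        stand_in_kernel(x_test, x_train),
                                        y_train, f_train_0, f_test_0, t=5.)
```

At t = 0 the outputs are unchanged, and as t grows the training outputs converge to y_train at rates set by the NTK eigenvalues.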

Weight space

Continuous gradient descent in an infinite network has been shown in [11] to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.

For this, we provide two convenient functions:

  • nt.linearize, and
  • nt.taylor_expand,

which allow us to linearize or get an arbitrary-order Taylor expansion of any function apply_fn(params, x) around some initial parameters params_0 as apply_fn_lin = nt.linearize(apply_fn, params_0).

You can use apply_fn_lin(params, x) exactly as you would any other function (including as an input to JAX optimizers). This makes it easy to compare the training trajectory of a neural network with that of its linearization. Prior theory and experiments have examined the linearization of neural networks from inputs to logits or pre-activations, rather than from inputs to post-activations, which are substantially more nonlinear.


import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))

apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = np.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2

x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x)  # (3, 2) np.ndarray
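For the affine apply_fn above, the linearization is exact. To see what nt.linearize computes for a genuinely nonlinear function, here is the same first-order Taylor expansion, f(params_0, x) + J(params_0, x)(params - params_0), written out by hand for a single tanh layer. This is our own toy example (linearize_by_hand is hypothetical); nt.linearize obtains the Jacobian-vector product by automatic differentiation instead of a hand-derived formula.

```python
import numpy as np


def apply_fn(params, x):
    W, b = params
    return np.tanh(x @ W + b)  # nonlinear, so the linearization is only approximate


def linearize_by_hand(params_0):
    """First-order Taylor expansion of apply_fn around params_0."""
    W_0, b_0 = params_0

    def apply_fn_lin(params, x):
        W, b = params
        f_0 = apply_fn(params_0, x)
        # d tanh(z) = (1 - tanh(z)**2) dz, where dz = x @ (W - W_0) + (b - b_0)
        dz = x @ (W - W_0) + (b - b_0)
        return f_0 + (1 - f_0**2) * dz

    return apply_fn_lin


W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
apply_fn_lin = linearize_by_hand((W_0, b_0))

for eps in (1e-1, 1e-2, 1e-3):
    params = (W_0 + eps * np.ones((2, 2)), b_0 + eps)
    gap = np.abs(apply_fn(params, x) - apply_fn_lin(params, x)).max()
    print(f"eps={eps:g}: max |f - f_lin| = {gap:.2e}")  # shrinks roughly like eps**2
```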

Function space

Outputs of a linearized model evolve identically to those of an infinite one [11], but with a different kernel: specifically, the Neural Tangent Kernel [10] evaluated on the specific apply_fn of the finite network, given the specific params_0 that the network is initialized with. For this we provide the nt.empirical_kernel_fn function, which accepts any apply_fn and returns a kernel_fn(x1, x2, get, params) that computes the empirical NTK and/or NNGP kernels (depending on get) for specific params.


import jax.random as random
import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)

key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))

kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)

t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
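The empirical NTK is the Gram matrix of parameter Jacobians, Theta(x1, x2) = J(x1) J(x2)^T, evaluated at the given params. The NumPy sketch below makes that definition explicit for one output unit, using central finite differences over the flattened parameters; for the affine apply_fn above it reduces to x1 · x2 + 1. The helpers output_jacobian and empirical_ntk are hypothetical: nt.empirical_kernel_fn uses automatic differentiation and handles all output units, so consult its documentation for how the output axis is reduced.

```python
import numpy as np


def apply_fn(params, x):
    W, b = params
    return x @ W + b


def output_jacobian(apply_fn, params, x, output=0, eps=1e-4):
    """Central finite-difference Jacobian of apply_fn(params, x)[:, output]
    with respect to all parameters, flattened into one vector."""
    shapes = [p.shape for p in params]
    splits = np.cumsum([p.size for p in params])[:-1]
    flat = np.concatenate([p.ravel() for p in params])

    def unflatten(v):
        return tuple(chunk.reshape(s) for chunk, s in zip(np.split(v, splits), shapes))

    jac = np.zeros((x.shape[0], flat.size))
    for i in range(flat.size):
        step = np.zeros_like(flat)
        step[i] = eps
        f_plus = apply_fn(unflatten(flat + step), x)[:, output]
        f_minus = apply_fn(unflatten(flat - step), x)[:, output]
        jac[:, i] = (f_plus - f_minus) / (2 * eps)
    return jac


def empirical_ntk(apply_fn, params, x1, x2, output=0):
    """Theta(x1, x2) = J(x1) @ J(x2).T for a single output unit."""
    j1 = output_jacobian(apply_fn, params, x1, output)
    j2 = output_jacobian(apply_fn, params, x2, output)
    return j1 @ j2.T


rng = np.random.default_rng(1)
params = (np.eye(2), np.zeros((2,)))
x_train, x_test = rng.normal(size=(3, 2)), rng.normal(size=(4, 2))
ntk_test_train = empirical_ntk(apply_fn, params, x_test, x_train)  # shape (4, 3)
```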

What to Expect

The success or failure of the linear approximation is highly architecture-dependent. However, some rules of thumb that we've observed are:

  1. Convergence as the network size increases.

    • For fully-connected networks one generally observes very strong agreement by the time the layer-width is 512 (RMSE of about 0.05 at the end of training).

    • For convolutional networks one generally observes reasonable agreement by the time the number of channels is 512.

  2. Convergence at small learning rates.

With a new model it is therefore advisable to start with a very large model on a small dataset using a small learning rate.


Performance

In the table below we measure time to compute a single NTK entry in a 21-layer CNN (3x3 filters, no strides, SAME padding, ReLU) on inputs of shape 3x32x32. Precisely:

from neural_tangents import stax

layers = []
for _ in range(21):
  layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]

CNN with pooling

Top layer is stax.GlobalAvgPool():

_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))

Platform                     | Precision (bits) | Milliseconds / NTK entry | Max batch size (NxN)
CPU, >56 cores, >700 GB RAM  | 32               | 112.90                   | >= 128
CPU, >56 cores, >700 GB RAM  | 64               | 258.55                   | 95 (fastest: 72)
TPU v2                       | 32/16            | 3.2550                   | 16
TPU v3                       | 32/16            | 2.3022                   | 24
NVIDIA P100                  | 32               | 5.9433                   | 26
NVIDIA P100                  | 64               | 11.349                   | 18
NVIDIA V100                  | 32               | 2.7001                   | 26
NVIDIA V100                  | 64               | 6.2058                   | 18

CNN without pooling

Top layer is stax.Flatten():

_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))

Platform                     | Precision (bits) | Milliseconds / NTK entry | Max batch size (NxN)
CPU, >56 cores, >700 GB RAM  | 32               | 0.12013                  | 2048 <= N < 4096 (fastest: 512)
CPU, >56 cores, >700 GB RAM  | 64               | 0.3414                   | 2048 <= N < 4096 (fastest: 256)
TPU v2                       | 32/16            | 0.0015722                | 512 <= N < 1024
TPU v3                       | 32/16            | 0.0010647                | 512 <= N < 1024
NVIDIA P100                  | 32               | 0.015171                 | 512 <= N < 1024
NVIDIA P100                  | 64               | 0.019894                 | 512 <= N < 1024
NVIDIA V100                  | 32               | 0.0046510                | 512 <= N < 1024
NVIDIA V100                  | 64               | 0.010822                 | 512 <= N < 1024

Tested using version 0.2.1. All GPU results are per single accelerator. Note that runtime is proportional to the depth of your network. If your performance differs significantly, please file a bug!

Myrtle network

Colab notebook Performance Benchmark demonstrates how one would construct and benchmark kernels. To demonstrate flexibility, we took the architecture from [16] as an example. On NVIDIA V100 GPUs with 64-bit precision, nt took 316/330/508 GPU-hours on the full 60k CIFAR-10 dataset for the Myrtle-5/7/10 kernels.
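As a rough sanity check, those GPU-hours can be converted to an average cost per kernel entry. The sketch assumes the full 60,000 x 60,000 matrix (50k training plus 10k test images) was evaluated; if only one triangle of the symmetric matrix was computed, the per-entry cost would be roughly twice as high. The helper name is hypothetical.

```python
def ms_per_kernel_entry(gpu_hours, n=60_000):
    """Average milliseconds per entry if the full n x n kernel matrix was evaluated."""
    return gpu_hours * 3600 * 1000 / (n * n)


for name, hours in (("Myrtle-5", 316), ("Myrtle-7", 330), ("Myrtle-10", 508)):
    print(f"{name}: {ms_per_kernel_entry(hours):.3f} ms / entry")
```

Because 60,000^2 entries happen to equal 3.6 x 10^9, one GPU-hour corresponds to exactly 0.001 ms per entry under this assumption.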


Neural Tangents has been used in the following papers (newest first):

  1. Do autoencoders need a bottleneck for anomaly detection?
  2. Finding Dynamics Preserving Adversarial Winning Tickets
  3. Learning Representation from Neural Fisher Kernel with Low-rank Approximation
  4. MIT 6.S088 Modern Machine Learning: Simple Methods that Work
  5. A Neural Tangent Kernel Perspective on Function-Space Regularization in Neural Networks
  6. Eigenspace Restructuring: a Principle of Space and Frequency in Neural Networks
  7. Functional Regularization for Reinforcement Learning via Learned Fourier Features
  8. A Structured Dictionary Perspective on Implicit Neural Representations
  9. Critical initialization of wide and deep neural networks through partial Jacobians: general theory and applications to LayerNorm
  10. Asymptotics of representation learning in finite Bayesian neural networks
  11. On the Equivalence between Neural Network and Support Vector Machine
  12. An Empirical Study of Neural Kernel Bandits
  13. Neural Networks as Kernel Learners: The Silent Alignment Effect
  14. Understanding Deep Learning via Analyzing Dynamics of Gradient Descent
  15. Neural Scene Representations for View Synthesis
  16. Neural Tangent Kernel Eigenvalues Accurately Predict Generalization
  17. Uniform Generalization Bounds for Overparameterized Neural Networks
  18. Data Summarization via Bilevel Optimization
  19. Neural Tangent Generalization Attacks
  20. Dataset Distillation with Infinitely Wide Convolutional Networks
  21. Neural Contextual Bandits without Regret
  22. Epistemic Neural Networks
  23. Uncertainty-aware Cardinality Estimation by Neural Network Gaussian Process
  24. Scale Mixtures of Neural Network Gaussian Processes
  25. Provably efficient machine learning for quantum many-body problems
  26. Wide Mean-Field Variational Bayesian Neural Networks Ignore the Data
  27. Spectral bias and task-model alignment explain generalization in kernel regression and infinitely wide neural networks
  28. Bridging Multi-Task Learning and Meta-Learning: Towards Efficient Training and Effective Adaptation
  29. What can linearized neural networks actually say about generalization?
  30. Measuring the sensitivity of Gaussian processes to kernel choice
  31. A Neural Tangent Kernel Perspective of GANs
  32. On the Power of Shallow Learning
  33. Learning Curves for SGD on Structured Features
  34. Out-of-Distribution Generalization in Kernel Regression
  35. Rapid Feature Evolution Accelerates Learning in Neural Networks
  36. Scalable and Flexible Deep Bayesian Optimization with Auxiliary Information for Scientific Problems
  37. Random Features for the Neural Tangent Kernel
  38. Multi-Level Fine-Tuning: Closing Generalization Gaps in Approximation of Solution Maps under a Limited Budget for Training Data
  39. Explaining Neural Scaling Laws
  40. Correlated Weights in Infinite Limits of Deep Convolutional Neural Networks
  41. Dataset Meta-Learning from Kernel Ridge-Regression
  42. Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the Neural Tangent Kernel
  43. Stable ResNet
  44. Label-Aware Neural Tangent Kernel: Toward Better Generalization and Local Elasticity
  45. Semi-supervised Batch Active Learning via Bilevel Optimization
  46. Temperature check: theory and practice for training models with softmax-cross-entropy losses
  47. Experimental Design for Overparameterized Learning with Application to Single Shot Deep Active Learning
  48. How Neural Networks Extrapolate: From Feedforward to Graph Neural Networks
  49. Exploring the Uncertainty Properties of Neural Networks’ Implicit Priors in the Infinite-Width Limit
  50. Cold Posteriors and Aleatoric Uncertainty
  51. Asymptotics of Wide Convolutional Neural Networks
  52. Finite Versus Infinite Neural Networks: an Empirical Study
  53. Bayesian Deep Ensembles via the Neural Tangent Kernel
  54. The Surprising Simplicity of the Early-Time Learning Dynamics of Neural Networks
  55. When Do Neural Networks Outperform Kernel Methods?
  56. Statistical Mechanics of Generalization in Kernel Regression
  57. Exact posterior distributions of wide Bayesian neural networks
  58. Infinite attention: NNGP and NTK for deep attention networks
  59. Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
  60. Finding trainable sparse networks through Neural Tangent Transfer
  61. Coresets via Bilevel Optimization for Continual Learning and Streaming
  62. On the Neural Tangent Kernel of Deep Networks with Orthogonal Initialization
  63. The large learning rate phase of deep learning: the catapult mechanism
  64. Spectrum Dependent Learning Curves in Kernel Regression and Wide Neural Networks
  65. Taylorized Training: Towards Better Approximation of Neural Network Training at Finite Width
  66. On the Infinite Width Limit of Neural Networks with a Standard Parameterization
  67. Disentangling Trainability and Generalization in Deep Learning
  68. Information in Infinite Ensembles of Infinitely-Wide Neural Networks
  69. Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel
  70. Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
  71. Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes

Please let us know if you make use of the code in a publication, and we'll add it to the list!


If you use the code in a publication, please cite our ICLR 2020 paper:

    title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
    author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
    booktitle={International Conference on Learning Representations},


References

[1] Priors for Infinite Networks
[2] Exponential expressivity in deep neural networks through transient chaos
[3] Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity
[4] Deep Information Propagation
[5] Deep Neural Networks as Gaussian Processes
[6] Gaussian Process Behaviour in Wide Deep Neural Networks
[7] Dynamical Isometry and a Mean Field Theory of CNNs: How to Train 10,000-Layer Vanilla Convolutional Neural Networks.
[8] Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes
[9] Deep Convolutional Networks as shallow Gaussian Processes
[10] Neural Tangent Kernel: Convergence and Generalization in Neural Networks
[11] Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
[12] Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation
[13] Mean Field Residual Networks: On the Edge of Chaos
[14] Wide Residual Networks
[15] On the Infinite Width Limit of Neural Networks with a Standard Parameterization
[16] Neural Kernels Without Tangents
