A Python library implementing the DKS/TAT neural network transformation method.

Official Python package for Deep Kernel Shaping (DKS) and Tailored Activation Transformations (TAT)

This Python package implements the activation function transformations, weight initializations, and dataset preprocessing used in Deep Kernel Shaping (DKS) and Tailored Activation Transformations (TAT). DKS and TAT, which were introduced in the DKS paper and TAT paper, are methods for constructing or transforming neural networks to make them much easier to train. For example, these methods can be used in conjunction with K-FAC to train deep vanilla convnets (without skip connections or normalization layers) as fast as standard ResNets of the same depth.

The package supports the JAX, PyTorch, and TensorFlow tensor programming frameworks.

Questions/comments about the code can be sent to dks-dev@google.com.

NOTE: we are not taking code contributions from GitHub at this time. All PRs from GitHub will be rejected. Instead, please email us if you find a bug.

Usage

For each of the supported tensor programming frameworks, there is a corresponding subpackage (dks.jax, dks.pytorch, or dks.tensorflow) which handles the activation function transformations, weight initializations, and (optional) data preprocessing. It's up to the user to import these and use them appropriately within their model code. Activation functions are transformed by the function get_transformed_activations() in the module activation_transform of the appropriate subpackage. Initial parameters are sampled using functions in the module parameter_sampling_functions of that subpackage. Data preprocessing is done using the function per_location_normalization in the module data_preprocessing of that subpackage.

Note that, to avoid having to import all of the tensor programming frameworks, the user must individually import whichever framework subpackage they want (e.g. import dks.jax). Meanwhile, import dks won't actually do anything.
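The import pattern described above might be wrapped as follows. This is a minimal sketch: the helper load_dks_modules and the two lookup tables are hypothetical names introduced here for illustration, while the subpackage and module names are the ones given in the text. Nothing is imported until the helper is called, so frameworks you do not use are never loaded.

```python
import importlib

# Subpackage and module names as listed in the text above.
SUBPACKAGES = {
    "jax": "dks.jax",
    "pytorch": "dks.pytorch",
    "tensorflow": "dks.tensorflow",
}
MODULES = (
    "activation_transform",
    "parameter_sampling_functions",
    "data_preprocessing",
)


def load_dks_modules(framework):
    """Hypothetical helper: import the dks modules for a single framework."""
    if framework not in SUBPACKAGES:
        raise ValueError(f"unsupported framework: {framework!r}")
    base = SUBPACKAGES[framework]
    # Only the requested framework's subpackage is imported here.
    return {name: importlib.import_module(f"{base}.{name}") for name in MODULES}
```

With dks and JAX installed, load_dks_modules("jax")["activation_transform"] would be the module that contains get_transformed_activations().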

get_transformed_activations() requires the user to pass either the "maximal slope function" for DKS, the "subnet maximizing function" for TAT with Leaky ReLUs, or the "maximal curvature function" for TAT with smooth activation functions. (The subnet maximizing function also handles DKS and TAT with smooth activations.) These are special functions that encode information about the particular model architecture. See the section titled "Summary of our method" of the DKS paper for a procedure to construct the maximal slope function for a given model, or the appendix section titled "Additional details and pseudocode for activation function transformations" of the TAT paper for procedures to construct the other two functions.
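As a rough illustration of what such a function looks like, consider the simplest case: a plain chain of depth nonlinear layers with no branching. Our reading of the DKS paper is that the slopes of composed layers multiply, so the maximal slope function of such a chain would be the local slope raised to the power depth. Treat this as an assumption and consult the paper for real architectures. The names chain_max_slope and solve_local_slope, and the target value 1.5, are hypothetical and are not part of the package.

```python
def chain_max_slope(local_slope, depth):
    """Maximal slope of a chain of `depth` nonlinear layers (assumed form)."""
    return local_slope ** depth


def solve_local_slope(target, depth, tol=1e-12):
    """Bisection for the local slope whose chain maximal slope equals `target`."""
    lo, hi = 1.0, target  # for target >= 1 the solution lies in [1, target]
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if chain_max_slope(mid, depth) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Illustrative values only: a 50-layer chain and a target maximal slope of 1.5.
local = solve_local_slope(target=1.5, depth=50)
print(local, chain_max_slope(local, 50))
```

The bisection is only there to show how a target value for the maximal slope pins down the slope each individual layer may have; deeper chains force the per-layer slope closer to 1.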

In addition to these things, the user is responsible for ensuring that their model meets the architectural requirements of DKS/TAT, and for converting any weighted sums into "normalized sums" (which are weighted sums whose non-trainable weights have a sum of squares equal to 1). See the section titled "Summary of our method" of the DKS paper for more details.
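The definition of a normalized sum can be made concrete with a small framework-free sketch. The helpers normalize_weights and normalized_sum are hypothetical names used only for illustration; in a real model the same rescaling would be applied to the fixed coefficients of each weighted sum in whichever framework you use.

```python
import math


def normalize_weights(weights):
    """Rescale fixed (non-trainable) weights so their squares sum to 1."""
    norm = math.sqrt(sum(w * w for w in weights))
    if norm == 0.0:
        raise ValueError("at least one weight must be nonzero")
    return [w / norm for w in weights]


def normalized_sum(branches, weights):
    """Weighted elementwise sum of equal-length branch outputs."""
    ws = normalize_weights(weights)
    return [sum(w * b[i] for w, b in zip(ws, branches))
            for i in range(len(branches[0]))]


# A plain two-branch sum x + y has weights (1, 1); normalized, each is 1/sqrt(2).
ws = normalize_weights([1.0, 1.0])
out = normalized_sum([[1.0, 2.0], [3.0, 4.0]], [1.0, 1.0])
```

In this example an ordinary sum of two branches becomes (x + y) / sqrt(2), which is the smallest change that satisfies the sum-of-squares condition.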

Note that the data preprocessing method implemented, called Per-Location Normalization (PLN), may not always be needed in practice, but we have observed certain situations where not using it can lead to problems. (For example, training on datasets that contain all-zero pixels, such as CIFAR-10.) Also note that ReLUs are only partially supported by DKS, and unsupported by TAT, and so their use is highly discouraged. Instead, one should use Leaky ReLUs, which are fully supported by DKS, and work especially well with TAT.
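To see how an all-zero pixel can be handled by a per-location normalization, here is a minimal sketch. It assumes, based on our reading of the DKS paper, that PLN appends a constant channel to each location and then rescales that location's channel vector to unit mean square. The function pln_sketch and its homog_value parameter are hypothetical, and the code does not reproduce the package's per_location_normalization.

```python
import math


def pln_sketch(pixel, homog_value=1.0):
    """Append a constant channel, then rescale to unit mean square.

    Assumed behaviour for illustration; not the package's implementation.
    """
    extended = list(pixel) + [homog_value]
    mean_sq = sum(c * c for c in extended) / len(extended)
    scale = 1.0 / math.sqrt(mean_sq)
    return [c * scale for c in extended]


black = pln_sketch([0.0, 0.0, 0.0])   # all-zero pixel stays well defined
bright = pln_sketch([1.0, 1.0, 1.0])
```

Without the extra channel, the all-zero pixel would have zero mean square and could not be rescaled at all; with it, every location ends up with the same mean square.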

Example

dks.examples.haiku.modified_resnet is a Haiku ResNet model which has been modified as described in the DKS/TAT papers, and includes support for both DKS and TAT. When constructed with its default arguments, it removes the normalization layers and skip connections found in standard ResNets, making it a "vanilla network". It can be used as an instructive example for how to build DKS/TAT models using this package. See the section titled "Application to various modified ResNets" from the DKS paper for more details.

Installation

This package can be installed directly from GitHub using pip with

pip install git+https://github.com/deepmind/dks.git

or

pip install -e git+https://github.com/deepmind/dks.git#egg=dks[<extras>]

Or from PyPI with

pip install dks

or

pip install dks[<extras>]

Here <extras> is a comma-separated list of strings (with no spaces) that can be passed to install extra dependencies for different tensor programming frameworks. Valid strings are jax, tf, and pytorch. So for example, to install dks with the extra requirements for JAX and PyTorch, one does

pip install dks[jax,pytorch]

Testing

To run tests in a Python virtual environment with specific pinned versions of all the dependencies one can do:

git clone https://github.com/deepmind/dks.git
cd dks
./test.sh

However, it is strongly recommended that you run the tests in the same Python environment (with the same package versions) as you plan to actually use dks. This can be accomplished by installing dks for all three tensor programming frameworks (e.g. with pip install dks[jax,pytorch,tf] or some other installation method), and then doing

pip install pytest-xdist
git clone https://github.com/deepmind/dks.git
cd dks
python -m pytest -n 16 tests

Disclaimer

This is not an official Google product.
