ReZero networks

ReZero for Deep Neural Networks

ReZero is All You Need: Fast Convergence at Large Depth. arXiv, March 2020.

Thomas Bachlechner*, Bodhisattwa Prasad Majumder*, Huanru Henry Mao*, Garrison W. Cottrell, Julian McAuley (* denotes equal contributions)

This package contains the ReZero-Transformer implementation from the paper. It matches the interface of PyTorch's Transformer layers and can be used as a drop-in replacement. See the sections below for installation and usage.

Abstract

Deep networks have enabled significant performance gains across domains, but they often suffer from vanishing/exploding gradients. This is especially true for Transformer architectures, where models deeper than 12 layers are difficult to train without large datasets and computational budgets. In general, we find that inefficient signal propagation impedes learning in deep networks. In Transformers, multi-head self-attention is the main cause of this poor signal propagation. To facilitate deep signal propagation, we propose ReZero, a simple change to the architecture that initializes an arbitrary layer as the identity map, using a single additional learned parameter per layer. We apply this technique to language modeling and find that we can easily train ReZero-Transformer networks over a hundred layers. When applied to 12-layer Transformers, ReZero converges 56% faster on enwiki8. ReZero applies beyond Transformers to other residual networks, enabling 1,500% faster convergence for deep fully connected networks and 32% faster convergence for a ResNet-56 trained on CIFAR 10.
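
At its core, every ReZero residual connection computes x + alpha * F(x), where alpha is the single additional learned scalar per layer, initialized to zero, so each layer starts out as the identity map. A minimal sketch of the idea for a deep fully connected network (illustrative only; the class and variable names below are not part of the rezero package):

import torch
import torch.nn as nn

class ReZeroBlock(nn.Module):
    # One fully connected residual block with a ReZero gate (illustrative)
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        # The single additional learned parameter per layer, initialized to zero
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # x + alpha * F(x): the block is exactly the identity at initialization
        return x + self.alpha * torch.relu(self.fc(x))

# Stacking many such blocks gives a deep fully connected ReZero network
net = nn.Sequential(*[ReZeroBlock(128) for _ in range(64)])
out = net(torch.rand(32, 128))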

Installation

Simply install from pip:

pip install rezero

PyTorch 1.4 or greater is required.

Usage

We provide custom ReZero Transformer layers (RZTX).

For example, this will create a Transformer encoder:

import torch
import torch.nn as nn
from rezero.transformer import RZTXEncoderLayer

# A single ReZero encoder layer, stacked 6 times into a standard PyTorch encoder
encoder_layer = RZTXEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# Input of shape (sequence length, batch size, d_model)
src = torch.rand(10, 32, 512)
out = transformer_encoder(src)
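
Continuing the example above, attention masks should pass through the ReZero encoder just as they do with PyTorch's built-in layers, since RZTXEncoderLayer mirrors nn.TransformerEncoderLayer. A sketch of a causal mask for language modeling (this assumes the mask arguments match PyTorch's standard encoder API):

# Upper-triangular -inf mask: position i may only attend to positions <= i
seq_len = src.size(0)
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
out = transformer_encoder(src, mask=causal_mask)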

This will create a Transformer decoder:

import torch
import torch.nn as nn
from rezero.transformer import RZTXDecoderLayer

# A single ReZero decoder layer, stacked 6 times into a standard PyTorch decoder
decoder_layer = RZTXDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

# memory: encoder output of shape (source length, batch size, d_model)
# tgt: target sequence of shape (target length, batch size, d_model)
memory = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
out = transformer_decoder(tgt, memory)

Make sure the norm argument is left as None so that no LayerNorm is used in the Transformer.
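
For intuition on why no LayerNorm is needed: in a ReZero layer each sublayer output is scaled by a residual weight that is initialized to zero, so the layer starts as the identity map and trains without normalization. A rough sketch of the idea (not the package's actual code; the class name and details below are illustrative):

import torch
import torch.nn as nn

class SketchReZeroEncoderLayer(nn.Module):
    # Illustrative ReZero-style encoder layer: gated residuals, no LayerNorm
    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.ff = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                                nn.ReLU(),
                                nn.Linear(dim_feedforward, d_model))
        # Single learned residual weight per layer, initialized to zero
        self.resweight = nn.Parameter(torch.zeros(1))

    def forward(self, src):
        src = src + self.resweight * self.self_attn(src, src, src)[0]
        src = src + self.resweight * self.ff(src)
        return src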

See https://pytorch.org/docs/master/nn.html#torch.nn.Transformer for details on how to integrate custom Transformer layers into PyTorch.
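
As one possible integration path (a sketch; it assumes the ReZero stacks are argument-compatible with PyTorch's built-in ones, as the drop-in claim above suggests), pre-built encoder and decoder stacks can be passed to nn.Transformer through its custom_encoder and custom_decoder arguments:

import torch
import torch.nn as nn
from rezero.transformer import RZTXEncoderLayer, RZTXDecoderLayer

encoder = nn.TransformerEncoder(RZTXEncoderLayer(d_model=512, nhead=8), num_layers=6)
decoder = nn.TransformerDecoder(RZTXDecoderLayer(d_model=512, nhead=8), num_layers=6)

# nn.Transformer accepts pre-built stacks instead of constructing its own layers
model = nn.Transformer(d_model=512, nhead=8,
                       custom_encoder=encoder,
                       custom_decoder=decoder)

src = torch.rand(10, 32, 512)  # (source length, batch size, d_model)
tgt = torch.rand(20, 32, 512)  # (target length, batch size, d_model)
out = model(src, tgt)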

Download files


Source Distribution

rezero-0.1.0.tar.gz (3.9 kB)

Built Distribution

rezero-0.1.0-py3-none-any.whl (5.4 kB)

File details

Details for the file rezero-0.1.0.tar.gz.

File metadata

  • Download URL: rezero-0.1.0.tar.gz
  • Upload date:
  • Size: 3.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/46.0.0 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.0

File hashes

Hashes for rezero-0.1.0.tar.gz:

  • SHA256: d89916e9ba26688b6e105caaa017569c8d29e6d5004810e680cc3f7cdc2ce7f2
  • MD5: 70749a045feacc83803114961bf44c8e
  • BLAKE2b-256: 25a8a997ffb4e407727f88679c554f8121b0cfa484a02b4d6852a9679faa685d

File details

Details for the file rezero-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: rezero-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 5.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/46.0.0 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.0

File hashes

Hashes for rezero-0.1.0-py3-none-any.whl:

  • SHA256: edc9816fae53f928f4d187c8378e371ae98f854f19eb9aa75ec3aeaef4081a78
  • MD5: 29b8e9c2198cdaa312c9f8543ab0f3e4
  • BLAKE2b-256: ecd24751a70110fc219e6abe056f3cfe261bef8dfa03cbc20a7c0adaf25aef78
