Implements adaptive computation time RNNs in PyTorch, with the same interface as builtin RNNs.
pytorch-adaptive-computation-time
This library implements PyTorch modules for recurrent neural networks that can learn to execute variable-time algorithms, as presented in Adaptive Computation Time for Recurrent Neural Networks (Graves 2016). These models can learn patterns requiring varying amounts of computation for a fixed-size input, which is difficult or impossible for traditional neural networks. The library aims to be clean, idiomatic, and extensible, offering a similar interface to PyTorch’s builtin recurrent modules.
The main features are:
- A nearly drop-in replacement for torch.nn.RNN- and torch.nn.RNNCell-style RNNs, but with the power of variable computation time.
- A wrapper which adds adaptive computation time to any RNNCell.
- Data generators, configs, and training scripts to reproduce experiments from the paper.
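To make the idea of variable computation time concrete, here is a minimal pure-Python sketch of the halting rule described in Graves (2016). It is not this library's implementation. At each input step the network emits a halting value per ponder step, keeps pondering until the running sum reaches 1 - epsilon, and assigns the leftover probability (the remainder) to the final ponder step. The function name `act_halting_weights` and its arguments are hypothetical, chosen for illustration only. The paper also caps the number of ponder steps, which this sketch leaves out.

```python
from typing import List, Tuple


def act_halting_weights(
    halting_units: List[float], epsilon: float = 0.01
) -> Tuple[List[float], float]:
    """Return (weights, ponder) for one input step under the ACT halting rule.

    halting_units: halting values in (0, 1), one per candidate ponder step.
    weights: per-ponder-step weights used to average states; they sum to 1.
    ponder: number of ponder steps taken plus the remainder.
    """
    weights: List[float] = []
    cumulative = 0.0
    for h in halting_units:
        if cumulative + h >= 1.0 - epsilon:
            # Final ponder step: it receives whatever probability is left.
            remainder = 1.0 - cumulative
            weights.append(remainder)
            return weights, len(weights) + remainder
        weights.append(h)
        cumulative += h
    # Hypothetical choice for this sketch: fail loudly instead of forcing a halt.
    raise ValueError("halting units never reached the 1 - epsilon threshold")


# Halts at the third ponder step; weights are approximately [0.1, 0.3, 0.6].
weights, ponder = act_halting_weights([0.1, 0.3, 0.7, 0.9])
```

The averaged hidden state and output for that input step are then the weighted sums of the per-ponder-step states and outputs, using these weights.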
Example
Vanilla PyTorch GRU:
rnn = torch.nn.GRU(64, 128, num_layers=2)
output, hidden = rnn(inputs, initial_hidden)
GRU with adaptive computation time:
rnn = models.AdaptiveGRU(64, 128, num_layers=2, time_penalty=1e-3)
output, hidden, ponder_cost = rnn(inputs, initial_hidden)
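The extra return value corresponds to the ponder cost from Graves (2016). For reference, the paper's training objective is restated below in its own notation: N(t) is the number of ponder steps taken at input step t, R(t) is the remainder assigned to the last ponder step, and τ is the time penalty. This README does not say whether the returned ponder_cost is already scaled by time_penalty, so treat the mapping onto this library's values as an assumption and check the documentation before adding it to a loss.

```latex
\rho_t = N(t) + R(t), \qquad
\mathcal{P}(\mathbf{x}) = \sum_{t=1}^{T} \rho_t, \qquad
\hat{\mathcal{L}}(\mathbf{x}, \mathbf{y}) = \mathcal{L}(\mathbf{x}, \mathbf{y}) + \tau \, \mathcal{P}(\mathbf{x})
```

A larger τ pushes the model toward fewer ponder steps per input, trading accuracy for computation.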
Documentation
Documentation is hosted on Read the Docs.
BibTeX
You don’t need to cite this code, but if it helps you in your research and you’d like to:
@misc{swope2020ACT,
title = "pytorch-adaptive-computation-time",
author = "Swope, Aidan",
journal = "GitHub",
year = "2020",
url = "https://github.com/maxwells-daemons/pytorch-adaptive-computation-time"
}
If you use the experiment code, please also consider citing PyTorch Lightning.