Implements adaptive computation time RNNs in PyTorch, with an interface similar to PyTorch's builtin RNNs.
pytorch-adaptive-computation-time
This library implements PyTorch modules for recurrent neural networks that can learn to execute variable-time algorithms, as presented in "Adaptive Computation Time for Recurrent Neural Networks" (Graves, 2016). These models can learn patterns that require varying amounts of computation for a fixed-size input, which is difficult or impossible for traditional neural networks. The library aims to be clean, idiomatic, and extensible, offering an interface similar to that of PyTorch’s builtin recurrent modules.
The main features are:
- A nearly drop-in replacement for torch.nn.RNN- and torch.nn.RNNCell-style RNNs, but with the power of variable computation time.
- A wrapper which adds adaptive computation time to any RNNCell.
- Data generators, configs, and training scripts to reproduce experiments from the paper.
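To illustrate what an ACT wrapper around an RNNCell does, here is a minimal sketch of one input step, following the halting scheme described in Graves (2016). It is not this library's implementation or API: the class name `ACTCellSketch`, the parameters `epsilon` and `max_ponder`, and the use of a single linear halting unit are assumptions for illustration. The cell is applied repeatedly ("pondering") until the accumulated halting probability reaches `1 - epsilon`; the returned state is the halting-weighted mean of the intermediate states, and the per-example ponder cost is the number of steps taken plus the leftover remainder. The wrapped cell must accept one extra input feature, a flag that is 1 on the first ponder step.

```python
import torch
from torch import nn


class ACTCellSketch(nn.Module):
    """Hypothetical ACT wrapper (after Graves 2016) around an RNNCell-style module.

    Not this library's API. The wrapped cell must accept input_size + 1
    features: the extra feature is a flag that is 1 on the first ponder step.
    """

    def __init__(self, cell: nn.Module, hidden_size: int,
                 epsilon: float = 0.01, max_ponder: int = 10):
        super().__init__()
        self.cell = cell
        self.halt = nn.Linear(hidden_size, 1)  # halting unit: sigmoid(W s + b)
        self.epsilon = epsilon
        self.max_ponder = max_ponder  # upper bound M on ponder steps

    def forward(self, x: torch.Tensor, h: torch.Tensor):
        batch = x.shape[0]
        cum_halt = x.new_zeros(batch)   # sum of halting probs of non-final steps
        remainder = x.new_ones(batch)   # R(t), set when an example halts
        n_steps = x.new_zeros(batch)    # N(t), ponder steps taken so far
        running = torch.ones(batch, dtype=torch.bool, device=x.device)
        mean_state = torch.zeros_like(h)  # halting-weighted mean of states
        state = h
        for n in range(self.max_ponder):
            flag = x.new_full((batch, 1), 1.0 if n == 0 else 0.0)
            state = self.cell(torch.cat([x, flag], dim=1), state)
            p = torch.sigmoid(self.halt(state)).squeeze(1)
            n_steps = n_steps + running.float()
            if n == self.max_ponder - 1:
                halting = running  # force a halt at the last allowed step
            else:
                halting = running & (cum_halt + p >= 1 - self.epsilon)
            continuing = running & ~halting
            zero = torch.zeros_like(p)
            # The halting step gets the remainder; earlier steps get their own p.
            weight = torch.where(halting, 1 - cum_halt,
                                 torch.where(continuing, p, zero))
            mean_state = mean_state + weight.unsqueeze(1) * state
            remainder = torch.where(halting, 1 - cum_halt, remainder)
            cum_halt = cum_halt + torch.where(continuing, p, zero)
            running = continuing
            if not running.any():
                break
        # Per-example ponder cost: N(t) + R(t).
        return mean_state, n_steps + remainder
```

For example, `ACTCellSketch(nn.GRUCell(input_size + 1, hidden_size), hidden_size)` wraps a GRU cell. For simplicity, this sketch keeps running the cell on examples that have already halted (giving them zero weight) until the whole batch halts, trading some wasted computation for fully batched code.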
Example
Vanilla PyTorch GRU:
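```python
import torch
rnn = torch.nn.GRU(64, 128, num_layers=2)
# inputs: (seq_len, batch, 64); initial_hidden: (num_layers=2, batch, 128)
output, hidden = rnn(inputs, initial_hidden)
```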
GRU with adaptive computation time:
```python
rnn = models.AdaptiveGRU(64, 128, num_layers=2, time_penalty=1e-3)
# Same call as above, but also returns a ponder cost
output, hidden, ponder_cost = rnn(inputs, initial_hidden)
```
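For reference, Graves (2016) defines the ponder cost from the halting probabilities $h_t^n$ produced at ponder step $n$ of input step $t$. Here $\epsilon$ is a small constant, $M$ is an upper bound on ponder steps, $T$ is the sequence length, and $\mathcal{L}$ is the task loss. Assuming the `time_penalty` argument plays the role of the weight $\tau$ (the example above does not say whether the returned `ponder_cost` is already scaled by it), the training objective is:

```latex
\begin{aligned}
N(t) &= \min\Big\{ M,\; \min\Big\{ n' : \sum_{n=1}^{n'} h_t^{n} \ge 1 - \epsilon \Big\} \Big\} \\
R(t) &= 1 - \sum_{n=1}^{N(t)-1} h_t^{n} \\
\mathcal{P}(x) &= \sum_{t=1}^{T} \big( N(t) + R(t) \big) \\
\hat{\mathcal{L}}(x, y) &= \mathcal{L}(x, y) + \tau\, \mathcal{P}(x)
\end{aligned}
```

Because $N(t)$ is discrete, only the remainder $R(t)$ contributes gradients to the ponder cost, which encourages the network to halt in fewer steps.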
Documentation
Documentation is hosted on Read the Docs.
BibTeX
You don’t need to cite this code, but if it helps you in your research and you’d like to:
```bibtex
@misc{swope2020ACT,
  title = "pytorch-adaptive-computation-time",
  author = "Swope, Aidan",
  howpublished = "GitHub",
  year = "2020",
  url = "https://github.com/maxwells-daemons/pytorch-adaptive-computation-time"
}
```
If you use the experiment code, please also consider citing PyTorch Lightning.
File details
Details for the file pytorch-adaptive-computation-time-0.1.3.tar.gz.
File metadata
- Download URL: pytorch-adaptive-computation-time-0.1.3.tar.gz
- Upload date:
- Size: 14.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.0.5 CPython/3.8.3 Linux/4.18.14-arch1-1-ARCH
File hashes
Algorithm | Hash digest
---|---
SHA256 | bf5744326e623310db0c048824ff0091d8efd519ed426e64bdae81248f7ba7c3
MD5 | 0a078dafe735a57582adf66e49f66e94
BLAKE2b-256 | 28ec36041fb56a08079f6d87bc5f82b6246d74405d7c7f02ec08dd378dab718f
File details
Details for the file pytorch_adaptive_computation_time-0.1.3-py3-none-any.whl.
File metadata
- Download URL: pytorch_adaptive_computation_time-0.1.3-py3-none-any.whl
- Upload date:
- Size: 15.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.0.5 CPython/3.8.3 Linux/4.18.14-arch1-1-ARCH
File hashes
Algorithm | Hash digest
---|---
SHA256 | 85188c0e1cf284772065a2d515db0573f03af2da62268bdc6636f8cf586ff0c3
MD5 | f814ce644b466b1da921a330aeb8ead6
BLAKE2b-256 | 5518c0e8e0fd955271094bf7307dbec9daa7e1c96c9cb99978a7c4829226bbf6