Implements adaptive computation time RNNs in PyTorch, with the same interface as builtin RNNs.
pytorch-adaptive-computation-time
This library implements PyTorch modules for recurrent neural networks that can learn to execute variable-time algorithms, as presented in Adaptive Computation Time for Recurrent Neural Networks (Graves 2016). These models can learn patterns that require a varying amount of computation for a fixed-size input. This is difficult or impossible for traditional neural networks, which perform a fixed amount of computation per input step. The library aims to be clean, idiomatic, and extensible, offering an interface similar to PyTorch’s builtin recurrent modules.
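For reference, Graves (2016) defines the halting mechanism for a single input step $x_t$ as follows. Here $\mathcal{S}$ is the RNN transition, $x_t^n$ is $x_t$ concatenated with a binary flag that is 1 only on the first ponder step, $\sigma$ is the logistic sigmoid, $M$ is a maximum number of ponder steps, $\epsilon$ is a small constant, and $\tau$ is the time penalty. This summarizes the paper's equations; it is not a description of how this library implements them.

$$
\begin{aligned}
s_t^n &= \mathcal{S}\!\left(s_t^{n-1},\, x_t^n\right), \qquad s_t^0 = s_{t-1}, \qquad x_t^n = [\,x_t;\ \delta_{n,1}\,] \\
y_t^n &= W_y s_t^n + b_y, \qquad h_t^n = \sigma\!\left(W_h s_t^n + b_h\right) \\
N(t) &= \min\Big\{M,\ \min\big\{n' : \textstyle\sum_{n=1}^{n'} h_t^n \ge 1 - \epsilon\big\}\Big\} \\
R(t) &= 1 - \textstyle\sum_{n=1}^{N(t)-1} h_t^n \\
p_t^n &= \begin{cases} R(t) & \text{if } n = N(t) \\ h_t^n & \text{otherwise} \end{cases} \\
s_t &= \textstyle\sum_{n=1}^{N(t)} p_t^n s_t^n, \qquad y_t = \sum_{n=1}^{N(t)} p_t^n y_t^n \\
\rho_t &= N(t) + R(t), \qquad \hat{\mathcal{L}} = \mathcal{L} + \tau \textstyle\sum_t \rho_t
\end{aligned}
$$

Because the final step's weight is the remainder $R(t)$, the weights $p_t^n$ always sum to 1, and the ponder cost $\rho_t$ penalizes both the number of steps and the probability mass left over at halting.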
The main features are:
- A nearly drop-in replacement for torch.nn.RNN- and torch.nn.RNNCell-style RNNs, but with the power of variable computation time.
- A wrapper which adds adaptive computation time to any RNNCell.
- Data generators, configs, and training scripts to reproduce experiments from the paper.
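The wrapper idea above can be sketched in plain PyTorch. This is a minimal, hypothetical illustration of the halting mechanism from Graves (2016), not this library's implementation: the class name `ACTCellWrapper` and its `epsilon` and `max_ponder` parameters are assumptions. It only supports cells that map `(input, hidden)` to a single hidden tensor, such as `torch.nn.RNNCell` and `torch.nn.GRUCell` (not `torch.nn.LSTMCell`), and the wrapped cell must accept one extra input feature for the first-step flag.

```python
import torch
from torch import nn


class ACTCellWrapper(nn.Module):
    """Hypothetical sketch of ACT (Graves 2016) around an RNNCell-style module.

    `cell` must map (input, hidden) -> hidden, as torch.nn.RNNCell and
    torch.nn.GRUCell do, and must accept one extra input feature: the
    "first ponder step" flag used in the paper.
    """

    def __init__(self, cell: nn.Module, hidden_size: int,
                 epsilon: float = 0.01, max_ponder: int = 10):
        super().__init__()
        self.cell = cell
        self.halt = nn.Linear(hidden_size, 1)  # halting unit: h = sigmoid(W s + b)
        self.epsilon = epsilon
        self.max_ponder = max_ponder

    def forward(self, x: torch.Tensor, hidden: torch.Tensor):
        batch = x.shape[0]
        halt_sum = x.new_zeros(batch)    # sum of h over steps that did not halt
        remainder = x.new_ones(batch)    # R = 1 - halt_sum
        steps = x.new_zeros(batch)       # N: ponder steps taken so far
        running = torch.ones(batch, dtype=torch.bool, device=x.device)
        weighted = torch.zeros_like(hidden)  # running sum of p^n * s^n
        state = hidden
        for n in range(self.max_ponder):
            # Append the binary flag: 1 on the first ponder step, else 0.
            flag = x.new_full((batch, 1), 1.0 if n == 0 else 0.0)
            state = self.cell(torch.cat([x, flag], dim=1), state)
            h = torch.sigmoid(self.halt(state)).squeeze(1)
            if n == self.max_ponder - 1:
                halting_now = running  # forced halt at the step limit
            else:
                halting_now = running & (halt_sum + h >= 1 - self.epsilon)
            # p^n is the remainder on the halting step and h otherwise;
            # sequences that already halted get zero weight.
            p = torch.where(halting_now, remainder, h) * running.float()
            weighted = weighted + p.unsqueeze(1) * state
            steps = steps + running.float()
            still_going = running & ~halting_now
            halt_sum = halt_sum + torch.where(still_going, h, torch.zeros_like(h))
            remainder = 1 - halt_sum
            running = still_going
            if not running.any():
                break
        ponder_cost = steps + remainder  # per-sequence N + R
        return weighted, ponder_cost
```

For simplicity, this sketch keeps running the cell on sequences that have already halted and just gives them zero weight. The step count N is piecewise constant, so gradients of the ponder cost flow only through the remainder R, which matches how the paper treats N.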
Example
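Vanilla PyTorch GRU:

```python
import torch

rnn = torch.nn.GRU(64, 128, num_layers=2)  # input_size=64, hidden_size=128
output, hidden = rnn(inputs, initial_hidden)
```

GRU with adaptive computation time:

```python
# `models` refers to this library's models module
rnn = models.AdaptiveGRU(64, 128, num_layers=2, time_penalty=1e-3)
output, hidden, ponder_cost = rnn(inputs, initial_hidden)  # extra return value: the ponder cost
```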
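The extra return value, `ponder_cost`, is what distinguishes the ACT interface. Assuming it follows the paper's definition (for each input step, the number of ponder steps N plus the remainder R, summed over time), the per-step quantities can be computed from a list of halting probabilities as below. `halting_weights` is a hypothetical pure-Python helper for illustration, not part of this library, and whether the library already scales its `ponder_cost` by `time_penalty` is not stated here.

```python
from typing import List, Tuple


def halting_weights(h: List[float], epsilon: float = 0.01) -> Tuple[List[float], int, float]:
    """Hypothetical helper: ACT halting for one input step (Graves 2016).

    `h` holds the halting probabilities h^1, h^2, ... from each ponder step;
    its last entry is treated as the forced halt if the threshold is never
    reached. Returns (weights p^n, number of steps N, remainder R).
    """
    total = 0.0
    for n, h_n in enumerate(h, start=1):
        if total + h_n >= 1 - epsilon or n == len(h):
            remainder = 1 - total  # R = 1 - sum of the earlier h values
            weights = list(h[: n - 1]) + [remainder]
            return weights, n, remainder
        total += h_n
    raise ValueError("h must be non-empty")


weights, steps, remainder = halting_weights([0.25, 0.5, 0.5])
ponder_cost = steps + remainder
print(weights, steps, ponder_cost)  # → [0.25, 0.5, 0.25] 3 3.25
```

The last step's probability (0.5) is replaced by the remainder (0.25), so the weights sum to exactly 1 and the ponder cost for this step is 3 + 0.25.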
Documentation
Documentation is hosted on Read the Docs.
BibTeX
You don’t need to cite this code, but if it helps you in your research and you’d like to:
```bibtex
@misc{swope2020ACT,
  title = "pytorch-adaptive-computation-time",
  author = "Swope, Aidan",
  howpublished = "GitHub",
  year = "2020",
  url = "https://github.com/maxwells-daemons/pytorch-adaptive-computation-time"
}
```
If you use the experiment code, please also consider citing PyTorch Lightning.