
Implements adaptive computation time RNNs in PyTorch, with the same interface as builtin RNNs.


pytorch-adaptive-computation-time

This library implements PyTorch modules for recurrent neural networks that can learn to execute variable-time algorithms, as presented in Adaptive Computation Time for Recurrent Neural Networks (Graves 2016). These models can learn patterns that require a varying amount of computation for a fixed-size input, which is difficult or impossible for standard networks that perform a fixed amount of computation per input step. The library aims to be clean, idiomatic, and extensible, with an interface similar to that of PyTorch’s builtin recurrent modules.
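For reference, the halting mechanism from Graves (2016) can be summarized as follows, in the paper's notation. S is the RNN state transition, x_t^n is the input x_t augmented with a binary flag that is 1 only on the first ponder step, y_t^n = W_y s_t^n + b_y is the per-step output, sigma is the logistic sigmoid, M is the maximum number of ponder steps, epsilon is a small constant (0.01 in the paper), and tau is the time penalty. This summarizes the paper's method; it is not a description of this library's internals.

```latex
\begin{aligned}
s_t^n &= \begin{cases} S(s_{t-1}, x_t^1) & n = 1 \\ S(s_t^{n-1}, x_t^n) & n > 1 \end{cases}
\qquad h_t^n = \sigma\!\left(W_h s_t^n + b_h\right) \\
N(t) &= \min\Big\{ M,\ \min\Big\{ n' : \textstyle\sum_{n=1}^{n'} h_t^n \ge 1 - \varepsilon \Big\} \Big\} \\
R(t) &= 1 - \sum_{n=1}^{N(t)-1} h_t^n,
\qquad p_t^n = \begin{cases} R(t) & n = N(t) \\ h_t^n & \text{otherwise} \end{cases} \\
s_t &= \sum_{n=1}^{N(t)} p_t^n s_t^n,
\qquad y_t = \sum_{n=1}^{N(t)} p_t^n y_t^n \\
\rho_t &= N(t) + R(t),
\qquad \hat{\mathcal{L}}(x, y) = \mathcal{L}(x, y) + \tau \sum_t \rho_t
\end{aligned}
```

Because R(t) sums the halting probabilities only up to N(t) - 1, the weights p_t^n always sum to one, and the ponder cost rho_t is at least 2 (exactly 2 when the first halting probability already reaches 1 - epsilon, giving N(t) = 1 and R(t) = 1).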

The main features are:

  • Nearly drop-in replacements for torch.nn.RNN- and torch.nn.RNNCell-style modules, with the added ability to vary the amount of computation per input step.
  • A wrapper which adds adaptive computation time to any RNNCell.
  • Data generators, configs, and training scripts to reproduce experiments from the paper.
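The idea behind the wrapper (the second feature above) can be illustrated with a minimal, framework-free sketch of one adaptive input step. This is not the library's implementation: act_step, cell, and halting_unit are hypothetical names, with cell standing in for any RNN cell and halting_unit for the sigmoid halting unit, and states are plain lists of floats so the halting arithmetic is easy to follow.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def act_step(
    cell: Callable[[Vector, Vector], Vector],   # hypothetical RNN cell: (state, input) -> state
    halting_unit: Callable[[Vector], float],    # hypothetical halting unit: state -> probability in (0, 1)
    state: Vector,
    x: Vector,
    epsilon: float = 0.01,
    max_ponder_steps: int = 100,
) -> Tuple[Vector, float]:
    """Run one ACT input step; return (weighted mean state, ponder cost N + R)."""
    total_halt = 0.0
    weighted_state = [0.0] * len(state)
    for n in range(1, max_ponder_steps + 1):
        # A binary flag marks the first ponder step, as in Graves (2016).
        flag = 1.0 if n == 1 else 0.0
        state = cell(state, x + [flag])
        h = halting_unit(state)
        # Stop once cumulative halting probability reaches 1 - epsilon, or at the step limit.
        if total_halt + h >= 1.0 - epsilon or n == max_ponder_steps:
            remainder = 1.0 - total_halt  # R(t): leftover probability mass for the last step
            weighted_state = [w + remainder * s for w, s in zip(weighted_state, state)]
            return weighted_state, n + remainder
        total_halt += h
        weighted_state = [w + h * s for w, s in zip(weighted_state, state)]
    raise ValueError("max_ponder_steps must be >= 1")


# Toy usage: a cell that leaves the state unchanged and a constant halting probability of 0.3.
final_state, cost = act_step(lambda s, x: s, lambda s: 0.3, [2.0, -1.0], [0.5])
# Halts after 4 ponder steps with remainder ~0.1, so cost is ~4.1 and final_state is ~[2.0, -1.0].
```

In training, the per-step ponder costs would be summed over the sequence, scaled by a time penalty, and added to the task loss; the time_penalty argument in the example below presumably plays that role.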

Example

Vanilla PyTorch GRU:

import torch
rnn = torch.nn.GRU(64, 128, num_layers=2)
# inputs: (seq_len, batch, 64); initial_hidden: (num_layers=2, batch, 128)
output, hidden = rnn(inputs, initial_hidden)

GRU with adaptive computation time:

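from pytorch_adaptive_computation_time import models
rnn = models.AdaptiveGRU(64, 128, num_layers=2, time_penalty=1e-3)
# ponder_cost is the extra output: the computation penalty to add to the training loss
output, hidden, ponder_cost = rnn(inputs, initial_hidden)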

Documentation

Documentation is hosted on Read the Docs.

BibTeX

You don’t need to cite this code, but if it helps you in your research and you’d like to:

@misc{swope2020ACT,
  title   = "pytorch-adaptive-computation-time",
  author  = "Swope, Aidan",
  journal = "GitHub",
  year    = "2020",
  url     = "https://github.com/maxwells-daemons/pytorch-adaptive-computation-time"
}

If you use the experiment code, please also consider citing PyTorch Lightning.


Download files

Source Distribution

pytorch-adaptive-computation-time-0.1.3.tar.gz (14.2 kB)

Built Distribution

pytorch_adaptive_computation_time-0.1.3-py3-none-any.whl

File hashes

Hashes for pytorch-adaptive-computation-time-0.1.3.tar.gz:

Algorithm     Hash digest
SHA256        bf5744326e623310db0c048824ff0091d8efd519ed426e64bdae81248f7ba7c3
MD5           0a078dafe735a57582adf66e49f66e94
BLAKE2b-256   28ec36041fb56a08079f6d87bc5f82b6246d74405d7c7f02ec08dd378dab718f
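To check a downloaded archive against the SHA256 digest listed above, Python's standard hashlib module is enough. The sketch below assumes the source distribution was saved under its original file name in the current directory; sha256_of is a hypothetical helper name.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest of the 0.1.3 source distribution, copied from the table above.
EXPECTED_SDIST_SHA256 = "bf5744326e623310db0c048824ff0091d8efd519ed426e64bdae81248f7ba7c3"

if __name__ == "__main__":
    archive = Path("pytorch-adaptive-computation-time-0.1.3.tar.gz")  # assumed download location
    if archive.exists():
        ok = sha256_of(archive) == EXPECTED_SDIST_SHA256
        print("hash OK" if ok else "hash MISMATCH")
```

For installs, pip's hash-checking mode can enforce the same check from a requirements file (a line such as pytorch-adaptive-computation-time==0.1.3 --hash=sha256:<digest>, installed with pip install --require-hashes -r requirements.txt); in that mode every dependency must also be pinned with hashes.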

Hashes for pytorch_adaptive_computation_time-0.1.3-py3-none-any.whl:

Algorithm     Hash digest
SHA256        85188c0e1cf284772065a2d515db0573f03af2da62268bdc6636f8cf586ff0c3
MD5           f814ce644b466b1da921a330aeb8ead6
BLAKE2b-256   5518c0e8e0fd955271094bf7307dbec9daa7e1c96c9cb99978a7c4829226bbf6
