General purpose Hessian-free optimization in Theano

Project description

Theano-hf-py3

This is a Python 3 version of boulanni/theano-hf.

Original Description

I wrapped my Hessian-free code in a generic class, usable as a black-box to train your models if you can provide the cost function as a Theano expression.

It includes all the details in Martens (ICML 2010) and Martens & Sutskever (ICML 2011) crucial to make it work:

  • Tikhonov damping with the Levenberg-Marquardt heuristics,
  • Gauss-Newton matrix products (you specify a Theano expression s that splits your computational graph in two),
  • Proper handling of batches and mini-batches (an example SequenceDataset class is provided for variable-length input),
  • Conjugate gradient (CG) with information sharing, backtracking, preconditioning and termination conditions,
  • Structural damping for RNNs.
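
The damping and CG items above can be illustrated with a minimal sketch. It is not the package's implementation: `damped_cg`, `update_damping` and their parameters are hypothetical names, it uses plain Python lists to stay dependency-free, and it leaves out information sharing, backtracking and preconditioning. The thresholds 1/4 and 3/4 and the factor 3/2 follow the Levenberg-Marquardt heuristic as described in Martens (ICML 2010).

```python
def dot(u, v):
    """Inner product of two equal-length sequences."""
    return sum(a * b for a, b in zip(u, v))


def damped_cg(matvec, b, damping, max_iters=50, tol=1e-10):
    """Solve (G + damping * I) x = b by conjugate gradient.

    `matvec(v)` must return the product G v; G itself is never formed.
    Hypothetical helper for illustration only.
    """
    x = [0.0] * len(b)
    r = list(b)              # residual b - A x, with x starting at zero
    p = list(r)              # first search direction
    rs = dot(r, r)
    for _ in range(max_iters):
        if rs < tol:         # squared residual small enough: stop
            break
        # Tikhonov damping: add damping * p to the curvature product
        Ap = [g + damping * pi for g, pi in zip(matvec(p), p)]
        alpha = rs / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


def update_damping(damping, rho, boost=1.5):
    """Levenberg-Marquardt heuristic.

    `rho` is the reduction ratio: the actual decrease in cost divided by
    the decrease predicted by the quadratic model.
    """
    if rho < 0.25:           # model was too optimistic: damp more
        return damping * boost
    if rho > 0.75:           # model was accurate: damp less
        return damping / boost
    return damping


# Toy usage with an explicit 2x2 matrix standing in for G
G = [[4.0, 1.0], [1.0, 3.0]]
step = damped_cg(lambda v: [dot(row, v) for row in G], [1.0, 2.0], damping=1.0)
```

In a Hessian-free setting, `matvec` would be the Gauss-Newton matrix product evaluated on a CG batch rather than a product with an explicit matrix.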

It relies heavily on Theano's R-operator (Rop). In practice, I could make it work without hassle for a feed-forward network, an RNN with different objectives, NADE (Larochelle) and a more complex model (RNN-NADE) that ties two scans together, so it seems reasonably flexible. Only the gradients and Gauss-Newton matrix products (95% of the computation) are in Theano; CG and the training logic are in Python. It runs on GPU, but for the models I tried, it was a bit slower than on CPU. Hessian-free is slow and needs CG batch sizes of 1000+ (don't skimp on this), but you can get much better results than SGD from it with almost zero tweaking.
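
The role of the R-operator can be made concrete with the usual Gauss-Newton factorisation. As a sketch, assume θ denotes the model parameters, s(θ) the expression that splits the graph, L the loss written as a function of s, and v a direction proposed by CG. These symbols are illustrative and are not taken from the package's code.

```latex
\begin{aligned}
J &= \frac{\partial s}{\partial \theta}, \qquad
H_L = \frac{\partial^2 L}{\partial s^2}, \\
G &= J^\top H_L J, \\
G v &= J^\top \bigl( H_L \, ( J v ) \bigr).
\end{aligned}
```

Under this reading, the R-operator supplies the forward product Jv, and the remaining two factors are applied by back-propagation, so G is never formed explicitly and each product costs on the order of a few gradient evaluations.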

There is an option to save and recover a checkpoint of training and do early stopping.

I included an RNN example that can memorize an input for 100 time steps (example_RNN). Launch it on 4 cores, come back in 8 hours, and you should have at least one nice solution with 0 error on the validation set. In comparison, SGD can solve this problem about 0.0% of the time.

It is available here: https://github.com/boulanni/theano-hf

If you use this software for academic research, please cite the following paper:

[1] N. Boulanger-Lewandowski, Y. Bengio and P. Vincent, "Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Generation and Transcription", Proc. ICML 29, 2012.

Author: Nicolas Boulanger-Lewandowski, University of Montreal, 2012

Download files

Download the file for your platform.

Source Distribution

theano-hf-0.2.0.tar.gz (7.6 kB)

Uploaded Source

Built Distribution

theano_hf-0.2.0-py3-none-any.whl (18.5 kB)

Uploaded Python 3

File details

Details for the file theano-hf-0.2.0.tar.gz.

File metadata

  • Download URL: theano-hf-0.2.0.tar.gz
  • Size: 7.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.45.0 CPython/3.8.3

File hashes

Hashes for theano-hf-0.2.0.tar.gz
Algorithm    Hash digest
SHA256       fc1a31a1a50926b06ad31201de1b9dca38dd2036b03eef1480dbffec2140ab3c
MD5          ce6e0b575e8adf95df742adae6f9e3bd
BLAKE2b-256  a4a8fd144245d3fabf3e838709b786b65ef6b7c6a950d1e20be3f5a0be06ab04

File details

Details for the file theano_hf-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: theano_hf-0.2.0-py3-none-any.whl
  • Size: 18.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.45.0 CPython/3.8.3

File hashes

Hashes for theano_hf-0.2.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       33fb0a3dd300050655bb8db6f4c7411b820e5f77375b4adaa2e85239515c8dea
MD5          3a3a92166e7ff7c60ab5c3932ad161a2
BLAKE2b-256  9a8a75e14f2c547087440df5c89d00d80b00b201cef38d2cacb4ec7c8e4eb363
