General purpose Hessian-free optimization in Theano

Project description

Theano-hf-py3

This is a Python 3 version of boulanni/theano-hf.

Install

pip install --upgrade theano-hf

Usage

import theano_hf

Original Description

I wrapped my Hessian-free code in a generic class, usable as a black box to train your models, provided you can express the cost function as a Theano expression.

It includes all the details in Martens (ICML 2010) and Martens & Sutskever (ICML 2011) crucial to make it work:

  • Tikhonov damping with the Levenberg-Marquardt heuristics,
  • Gauss-Newton matrix products (you specify a Theano expression s that splits your computational graph in two),
  • Proper handling of batches and mini-batches (an example SequenceDataset class is provided for variable-length input),
  • Conjugate gradient (CG) with information sharing, backtracking, preconditioning and termination conditions,
  • Structural damping for RNNs.
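
To make the damping and CG items above concrete, here is a minimal NumPy sketch, not the package's own code. It solves the Tikhonov-damped system (G + lambda I) x = b by conjugate gradient and adjusts lambda with the Levenberg-Marquardt rule from Martens (2010). The names damped_cg, update_lambda, matvec and x0 are hypothetical; x0 stands in for the information sharing between successive CG runs, and backtracking and preconditioning are left out.

```python
import numpy as np


def damped_cg(matvec, b, lam, x0=None, max_iters=250, tol=1e-10):
    """Solve (G + lam * I) x = b by conjugate gradient.

    matvec(v) must return G @ v for a symmetric positive semi-definite G.
    x0 is an optional warm start, e.g. the previous CG solution.
    """
    x = np.zeros_like(b, dtype=float) if x0 is None else np.array(x0, dtype=float)
    r = b - (matvec(x) + lam * x)        # residual of the damped system
    p = r.copy()                         # first search direction
    rs = r @ r
    for _ in range(max_iters):
        if rs < tol:                     # squared residual norm small enough
            break
        Ap = matvec(p) + lam * p         # damped matrix-vector product
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p        # next conjugate direction
        rs = rs_new
    return x


def update_lambda(lam, rho):
    """Levenberg-Marquardt heuristic (Martens, 2010).

    rho is the ratio of the actual cost reduction to the reduction
    predicted by the quadratic model.
    """
    if rho < 0.25:
        return lam * 1.5                 # model too optimistic: damp more
    if rho > 0.75:
        return lam / 1.5                 # model trustworthy: damp less
    return lam
```

Passing G as a matvec callable rather than as a matrix mirrors the Hessian-free setting, where only products with the curvature matrix are ever computed.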

It relies heavily on Theano's R-operator (Rop). In practice, I could make it work without hassle for a feed-forward network, an RNN with different objectives, NADE (Larochelle) and a more complex model (RNN-NADE) that ties two scans together, so it seems reasonably flexible. Only the gradients and Gauss-Newton matrix products (95% of the computation) are in Theano; CG and the training logic are in Python. It runs on GPU, but for the models I tried, it was a bit slower than on CPU. Hessian-free is slow and needs CG batch sizes of 1000+ (don't skimp on this), but you can get much better results than SGD from it with almost zero tweaking.
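
As an illustration of what a Gauss-Newton matrix product computes, here is a small NumPy sketch under simplifying assumptions: a logistic-regression model whose split point s = X w is linear in the parameters, with a summed cross-entropy loss. The function names are hypothetical and this is not the package's Theano code. In Theano, the J v step would be the job of the R-operator (theano.tensor.Rop).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gauss_newton_vp(X, w, v):
    """Return G @ v for logistic regression with pre-activations s = X @ w.

    G = J^T H_L J, where J = ds/dw = X and H_L = d^2 L / ds^2, which is
    diagonal for the summed cross-entropy loss. G is never formed.
    """
    s = X @ w                    # the expression that splits the graph in two
    Jv = X @ v                   # forward, R-op style pass: J v
    p = sigmoid(s)
    HJv = p * (1.0 - p) * Jv     # loss Hessian w.r.t. s, applied to J v
    return X.T @ HJv             # backward pass: J^T (H_L J v)
```

Because H_L is positive semi-definite for this loss, v^T G v is never negative, which is what makes G a safe curvature matrix for CG.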

There is an option to save and recover a checkpoint of training and do early stopping.
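
One common way to combine checkpointing with early stopping is a patience counter, sketched below. This is an assumption about the general technique, not a description of the package's implementation. The function train_with_early_stopping and its parameters step, validate and checkpoint_path are hypothetical.

```python
import pickle


def train_with_early_stopping(step, validate, num_updates, patience,
                              checkpoint_path=None):
    """Run up to num_updates updates, stopping once the validation cost
    has not improved for `patience` consecutive updates.

    step(u) performs update u and returns a picklable training state;
    validate(state) returns the validation cost for that state.
    """
    best_cost = float("inf")
    best_update = -1
    for u in range(num_updates):
        state = step(u)
        cost = validate(state)
        if cost < best_cost:
            best_cost, best_update = cost, u
            if checkpoint_path is not None:
                # Save only the best state seen so far.
                with open(checkpoint_path, "wb") as f:
                    pickle.dump({"update": u, "cost": cost, "state": state}, f)
        elif u - best_update >= patience:
            break                        # no improvement for `patience` updates
    return best_update, best_cost
```

Training can later resume by unpickling the saved dictionary and continuing from the stored state.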

I included an RNN example that can memorize an input for 100 time steps (example_RNN). Launch it on 4 cores, come back in 8 hours, and you should have at least one nice solution with 0 error on the validation set. In comparison, SGD can solve this problem about 0.0% of the time.

It is available here: https://github.com/boulanni/theano-hf

If you use this software for academic research, please cite the following paper:

[1] N. Boulanger-Lewandowski, Y. Bengio and P. Vincent, "Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Generation and Transcription", Proc. ICML 29, 2012.

Author: Nicolas Boulanger-Lewandowski, University of Montreal, 2012
