Train a neural network - with recursion!

recursive-sgd

A proof of concept of stochastic gradient descent implemented with recursion, written in Python.

Why?

IDK, it just hit me one day that, if I represent a neural network as a sequence of transformations, I could train it recursively.

How?

The idea is simple: to train some layer of a neural network, a recursive function first trains the next layer and returns the gradient w.r.t. the inputs of that next layer (which are the outputs of the current layer). This gradient is then used to update the current layer and to calculate the gradient w.r.t. the inputs of the current layer, which is in turn returned to the previous layer. The gist of it is in the sgd_step function of recursive_sgd/sgd.py.
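To make the recursion concrete, here is a minimal sketch of the idea in pure Python. It is not the code from recursive_sgd/sgd.py: the Linear and Tanh classes, the predict helper, the squared-error base case and the signature of this sgd_step are all assumptions made for illustration.

```python
import math
import random


class Linear:
    """Hypothetical dense layer y = W x (no bias), stored as nested lists."""

    def __init__(self, n_in, n_out, rng):
        self.w = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)]
                  for _ in range(n_out)]

    def forward(self, x):
        return [sum(w * v for w, v in zip(row, x)) for row in self.w]

    def backward(self, x, y, grad_y, lr):
        # Gradient w.r.t. the inputs, computed with the weights *before* the update.
        grad_x = [sum(self.w[i][j] * grad_y[i] for i in range(len(self.w)))
                  for j in range(len(x))]
        # SGD update of the weights: dL/dW[i][j] = grad_y[i] * x[j].
        for i, row in enumerate(self.w):
            for j in range(len(row)):
                row[j] -= lr * grad_y[i] * x[j]
        return grad_x


class Tanh:
    """Hypothetical activation layer; it has no parameters to update."""

    def forward(self, x):
        return [math.tanh(v) for v in x]

    def backward(self, x, y, grad_y, lr):
        # d tanh(z)/dz = 1 - tanh(z)^2, and y already holds tanh(z).
        return [g * (1.0 - out * out) for g, out in zip(grad_y, y)]


def sgd_step(layers, x, target, lr):
    """Train layers[0] by first recursing into layers[1:].

    Returns the gradient of the loss w.r.t. x, the input of layers[0].
    """
    if not layers:
        # Base case (assumed loss): sum of squared errors, dL/dx = 2 * (x - t).
        return [2.0 * (xi - ti) for xi, ti in zip(x, target)]
    layer, rest = layers[0], layers[1:]
    y = layer.forward(x)
    grad_y = sgd_step(rest, y, target, lr)    # train everything after this layer
    return layer.backward(x, y, grad_y, lr)   # update this layer, pass gradient back


def predict(layers, x):
    """Forward pass, written recursively as well."""
    return x if not layers else predict(layers[1:], layers[0].forward(x))


if __name__ == "__main__":
    rng = random.Random(0)
    net = [Linear(2, 3, rng), Tanh(), Linear(3, 1, rng)]
    for _ in range(200):
        sgd_step(net, [1.0, -1.0], [0.5], lr=0.1)
    print(predict(net, [1.0, -1.0]))
```

The forward pass happens on the way down the call stack and the backward pass on the way back up, so the call stack itself plays the role of the activation cache.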

Once again, why?

The answer is left as an exercise to the reader.

Installation

Ordinary pip3 install recursive-sgd does the trick. Alternatively, one can use:

git clone https://github.com/InCogNiTo124/recursive-sgd.git
cd recursive-sgd
python3 setup.py install

Usage

There's a CLI available.

Training

python3 cli.py train [OPTIONS] where OPTIONS can be the following:

  • -d FILEPATH or --dataset FILEPATH - CSV dataset
  • -i INT or --input-size INT - the number of input features
  • --lr FLOAT - learning rate
  • --loss VALUE - Loss function (CE for CrossEntropy, MSE for MeanSquaredError)
  • -e INT or --epochs INT - the number of epochs
  • --batch-size INT - the size of one batch
  • --shuffle - shuffle dataset after every epoch (default)
  • --no-shuffle - never shuffle

The architecture is defined with arguments as well:

  • -l SIZE - a new layer with SIZE neurons
  • -b - add bias
  • -s - add Sigmoid activation
  • -r - add ReLU activation
  • -t - add Tanh activation
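How the flags above turn into a network can be illustrated with a small hand-rolled parser. This is only a sketch under assumptions: parse_architecture, the FLAGS table and the (kind, size) tuples are hypothetical names, and the real cli.py may parse and interpret the flags differently (for example, in how -b attaches to a layer).

```python
# Hypothetical mapping from the architecture flags to component names.
FLAGS = {"-b": "bias", "-s": "sigmoid", "-r": "relu", "-t": "tanh"}


def parse_architecture(argv):
    """Collect the architecture flags from argv, in order of appearance.

    Returns a list of (kind, size) tuples; size is None for everything
    except layers. Tokens that are not architecture flags are skipped.
    """
    spec = []
    args = iter(argv)
    for arg in args:
        if arg == "-l":
            size = next(args, None)  # -l must be followed by SIZE
            if size is None:
                raise ValueError("-l requires a SIZE argument")
            spec.append(("layer", int(size)))
        elif arg in FLAGS:
            spec.append((FLAGS[arg], None))
    return spec
```

Under these assumptions, the tokens -l 4 -b -t -l 1 -b -s would describe a 4-neuron layer with bias and Tanh, followed by a 1-neuron layer with bias and Sigmoid.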

Check out the example in train_command.

Testing

python3 cli.py test [OPTIONS] where OPTIONS can be the following:

  • -m FILEPATH or --model FILEPATH - the path to the saved model
  • -d FILEPATH or --dataset FILEPATH - CSV dataset
  • -i INT or --input-size INT - the number of input features
  • --metrics VALUE - the metric with which to test the model

Check out the example in test_command.

Notes

  • After training, the model will be saved in $PWD as MODEL.sgd. This is hardcoded for now, but will be configurable in the future.
  • There is neither a -h nor a --help flag. I parse the arguments myself, without any framework, and I didn't bother writing help anywhere in the CLI; this README is the only documentation.
  • There are serious limitations in the dataset loading:
    • Only CSV format is allowed
    • The columns MUST be separated by the , character only.
    • The true labels column is implicitly the last one
      • Technically, it's every column remaining after the number of features defined by -i or --input-size, which may introduce subtle bugs if that leaves more than one target variable.
  • The only metric available at the moment is accuracy.
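The dataset caveat about target columns is easiest to see in code. The sketch below is not the package's loader: split_row and load_rows are hypothetical helpers, and it assumes a header-less, purely numeric, comma-separated file.

```python
def split_row(line, input_size):
    """Split one CSV line into (features, targets).

    Everything after the first input_size values is treated as a target,
    which is why a wrong -i / --input-size silently changes the targets.
    """
    values = [float(v) for v in line.strip().split(",")]
    return values[:input_size], values[input_size:]


def load_rows(lines, input_size):
    """Apply split_row to every non-empty line of an iterable of lines."""
    return [split_row(line, input_size) for line in lines if line.strip()]
```

For a line such as 1,2,3,4, an input size of 3 yields one target, while an input size of 2 quietly yields two targets instead of raising an error.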

One final note

https://twitter.com/johnwilander/status/1176457013305303040
