Train a neural network - with recursion!
recursive-sgd
A proof of concept of stochastic gradient descent implemented with recursion, written in Python
Why?
IDK, it just hit me one day that, if I represent a neural network as a sequence of transformations, I could train it recursively.
How?
The idea is simple: to train some layer of a neural network, a recursive function first trains the next layer and returns the gradient of the loss w.r.t. the inputs of that next layer. This gradient is then used to update the current layer and to calculate the gradient w.r.t. the inputs of the current layer, which is in turn returned to the previous layer.
The gist of it is in the sgd_step function of recursive_sgd/sgd.py.
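The recursion described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the project's actual sgd_step: the Linear and Tanh classes, the loss helpers, the name recursive_sgd_step, and its signature are all made up for this example, and it trains on single examples with a half squared error rather than on batches.

```python
import math
import random


class Linear:
    """Fully connected layer with bias, stored as plain Python lists (hypothetical)."""

    def __init__(self, n_in, n_out, rng):
        self.w = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
        self.b = [0.0] * n_out

    def forward(self, x):
        return [sum(wi * xi for wi, xi in zip(row, x)) + b
                for row, b in zip(self.w, self.b)]

    def backward(self, x, y, grad_y, lr):
        # Gradient w.r.t. the inputs, computed with the weights before the update.
        grad_x = [sum(row[i] * g for row, g in zip(self.w, grad_y))
                  for i in range(len(x))]
        # Plain SGD update of this layer's own parameters.
        for j, g in enumerate(grad_y):
            for i, xi in enumerate(x):
                self.w[j][i] -= lr * g * xi
            self.b[j] -= lr * g
        return grad_x


class Tanh:
    """Element-wise activation; it has no parameters to update."""

    def forward(self, x):
        return [math.tanh(v) for v in x]

    def backward(self, x, y, grad_y, lr):
        return [g * (1.0 - yi * yi) for g, yi in zip(grad_y, y)]


def loss(pred, target):
    """Half of the summed squared error (chosen here for a simple gradient)."""
    return 0.5 * sum((p - t) ** 2 for p, t in zip(pred, target))


def loss_grad(pred, target):
    return [p - t for p, t in zip(pred, target)]


def predict(layers, x):
    for layer in layers:
        x = layer.forward(x)
    return x


def recursive_sgd_step(layers, x, target, lr):
    """Train `layers` on one example and return the gradient of the loss w.r.t. x."""
    if not layers:
        # Base case: no layers left, so x is the network output.
        return loss_grad(x, target)
    layer, rest = layers[0], layers[1:]
    y = layer.forward(x)
    # Recurse: train every following layer and get the gradient w.r.t. y.
    grad_y = recursive_sgd_step(rest, y, target, lr)
    # Update this layer and hand the gradient w.r.t. x back to the caller.
    return layer.backward(x, y, grad_y, lr)


def train(layers, data, lr=0.05, epochs=300):
    """Run per-example SGD and return the total loss over `data` afterwards."""
    for _ in range(epochs):
        for x, target in data:
            recursive_sgd_step(layers, x, target, lr)
    return sum(loss(predict(layers, x), t) for x, t in data)
```

A network is then just a list such as `[Linear(2, 4, rng), Tanh(), Linear(4, 1, rng)]` with `rng = random.Random(0)`, and `train(net, data)` takes `data` as a list of `(inputs, targets)` pairs. The forward pass happens on the way down the call stack and the backward pass on the way back up, so the call stack plays the role of the activation cache.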
Once again, why?
The answer is left as an exercise to the reader
Installation
Ordinary pip3 install recursive-sgd does the trick.
Alternatively, one can use:
git clone https://github.com/InCogNiTo124/recursive-sgd.git
cd recursive-sgd
python3 setup.py install
Usage
There's a CLI available.
Training
python3 cli.py train [OPTIONS] where OPTIONS can be the following:
- -d FILEPATH or --dataset FILEPATH: CSV dataset
- -i INT or --input-size INT: the number of input features
- --lr FLOAT: learning rate
- --loss VALUE: loss function (CE for CrossEntropy, MSE for MeanSquaredError)
- -e INT or --epochs INT: the number of epochs
- --batch-size INT: the size of one batch
- --shuffle: shuffle dataset after every epoch (default)
- --no-shuffle: never shuffle
The architecture is defined with arguments as well:
- -l SIZE: a new layer with SIZE neurons
- -b: add bias
- -s: add Sigmoid activation
- -r: add ReLU activation
- -t: add Tanh activation
Check out the example in train_command.
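To make the architecture flags more concrete, here is a rough sketch of how such flags could be read in order into a sequence of transformations. It is not the project's parser: the function name parse_architecture, the tuple format it returns, and the choice to silently skip every other option are assumptions made for this illustration.

```python
# Flags that take no value; each one appends a single transformation.
SIMPLE_FLAGS = {"-b": "bias", "-s": "sigmoid", "-r": "relu", "-t": "tanh"}


def parse_architecture(argv):
    """Collect architecture flags, in order, as a list of tuples (hypothetical format)."""
    steps = []
    args = iter(argv)
    for arg in args:
        if arg == "-l":
            size = next(args, None)  # -l is the only flag here that takes a value
            if size is None:
                raise ValueError("-l expects a SIZE")
            steps.append(("linear", int(size)))
        elif arg in SIMPLE_FLAGS:
            steps.append((SIMPLE_FLAGS[arg],))
        # Anything else (--lr, --epochs, their values, ...) is ignored by this sketch.
    return steps


steps = parse_architecture("-l 4 -b -t -l 1 -b -s".split())
# steps == [("linear", 4), ("bias",), ("tanh",), ("linear", 1), ("bias",), ("sigmoid",)]
```

Under this reading the order of the flags matters: "-l 4 -b -t" describes a layer of 4 neurons, followed by a bias, followed by a Tanh activation.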
Testing
python3 cli.py test [OPTIONS] where OPTIONS can be the following:
- -m FILEPATH or --model FILEPATH: the path to the saved model
- -d FILEPATH or --dataset FILEPATH: CSV dataset
- -i INT or --input-size INT: the number of input features
- --metrics VALUE: the metric with which you wish to test the model
Check out the example in test_command.
Notes
- After training, the model will be saved in $PWD as MODEL.sgd. This is hardcoded for now, but will be configurable in the future.
- There is no -h nor --help flag. I am parsing the arguments myself without any framework at all, and I didn't bother writing help anywhere but here.
- There are serious limitations in the dataset loading:
  - Only the CSV format is allowed.
  - The columns MUST be separated by the , character only.
  - The true labels column is implicitly the last one.
    - Technically it's every remaining column after the number of features defined by -i or --input-size, which may introduce subtle bugs when that leaves more than one target variable.
- The only available metric at the moment is accuracy.
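The dataset limitations above are easier to see with a small example. The following sketch is not the project's loader; the function split_rows and the inline sample data are hypothetical. It only mimics the described behaviour: the first input-size columns are features and every remaining column is treated as a target.

```python
import csv
import io


def split_rows(csv_text, input_size):
    """Split each comma-separated row into (features, targets) at `input_size`."""
    features, targets = [], []
    for row in csv.reader(io.StringIO(csv_text)):  # the default delimiter is ','
        if not row:
            continue  # skip blank lines
        values = [float(v) for v in row]
        features.append(values[:input_size])
        targets.append(values[input_size:])  # everything left over becomes a target
    return features, targets


sample = "0,0,0\n0,1,1\n1,0,1\n1,1,0\n"

# Intended use: two features, one label column.
xs, ys = split_rows(sample, 2)
# xs[1] == [0.0, 1.0] and ys[1] == [1.0]

# Wrong input size: the second feature silently becomes a second target.
xs_bad, ys_bad = split_rows(sample, 1)
# xs_bad[1] == [0.0] and ys_bad[1] == [1.0, 1.0]
```

Nothing fails in the second call, which is why a mistyped -i value can go unnoticed until the results look off.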
One final note