Tools for train and prune in the dual space a fully connected layer.
Project description
spectraltools
spectraltools is a package for spectral training and analysis of fully connected feedforward NN.
According to our test it is well integrated in Tensorflow 2.10
and older versions up to Tensorflow 2.3
.
Installation
Activate the environment where the package is to be installed.
$ pip install spectraltools
Usage
Spectral layer
The package contains the spectral fully connected layer that can be imported as follows:
from spectraltools import Spectral
It is a representation in the reciprocal space of a fully connected layer.
The layer can be used inside a Tensorflow model and has three main attributes:
from TensorFlow.spectraltools import Spectral
Spectral(
units=300,
activation='relu',
is_base_trainable=True,
is_diag_end_trainable=True,
use_bias=False
)
In this configuration the layer is a fully connected layer with 300 nodes and ReLU activation function. The layer is equivalent to a dense
layer with a scalar parameter that multiplies the features. The layer is initialized with a random base and eigenvalues
equal to one (namely the initialization is equivalent to a fully-connected layer).
It implements the operation: output = activation(dot(input, spectral_kernel) + bias)
where
spectral_kernel = dot(base, diag_in) - dot(diag_out, base)
.
diag_in
and diag_out
are the eigenvalues of the adjacency matrix representing the layer and
base are the nontrivial components of its eigenvectors. bias
is a bias vector created by the layer
(only applicable if use_bias
is True
).
This configuration (where the eigenvectors and the diag_end are trained) is the one suggested and for which the
pruning function are developed. In the future other configurations support will be added.
Attributes description
If is_base_trainable=True
the eigenvectors of the adjacency matrix will be trained. This is equivalent to the
training of all the connections (features). Those are input_dim x output_dim
trainable parameters.
is_input_layer
(default set to False
) train the first input_dim
eigenvalues of the matrix and is_diag_end_trainable
trains the last
output_dim
eigenvalues. We recommend to set is_input_layer
to True
only for the first layer of the network and
leave it to False for the other layers. This is because the behaviour of the pruning algorithm has been tested and heuristically proven effective
only when in this setting.
The total number of trainable parameters is therefore input_dim x output_dim + input_dim + output_dim
.
If only the eigenvalues are trained the number of free parameters drops but the learning still occurs. A
suboptimal loss minimum is reached but overfitting is less likely to occur. If also eigenvectors are trained the layer
is, from a training point of view, the same as the Dense.
Spectral Pruning
The pruning function are tested with Functional or Sequential models implementing one or more Spectral layers. Best pruning results are achived when also an L2 regularization is applied to the spectral layer parameters: base and eigenvalues. There are two ways in which the pruning can be done:
- Percentile based Pruning: the pruning is done according to the eigenvalues distribution of every spectral layer in the model. The nodes with the smallest eigenvalues magnitude (according to the percentile given) are removed. The percentile of nodes to be removed is passed as an argument to the function. The compile configuration is needed It can be called as follows:
from TensorFlow.spectraltools import prune_percentile
pruned_model = prune_percentile(model,
percentile,
percentile_threshold)
model: Sequential
or Functional
model, employing one or more Spectral layers, that needs to be pruned.
percentile_threshold: the percentile (1-100) of nodes that the model should try to prune. The pruning is done by masking
the eigenvalues of the spectral layers which is equivalent to set all the corresponding features and biases to 0.
Example:
from tensorflow.keras.layers import Dense, Input
from TensorFlow.spectraltools import Spectral
inputs = Input(shape=(784,))
x = Dense(100,
activation='relu')(inputs)
x = Spectral(80,
activation='relu')(x)
In this case the prunable nodes will be 80.
The nodes are removed according to the eigenvalues
distribution which has been empirically and heuristically proven to be an indicator of node relevance inside the network.
- Metric based Pruning: the pruning is done according to the impact that the removal of a node has on the loss or another metric calculated on the dataset. The nodes with the smallest impact are removed. The impact is calculated by training the model on a validation set. that is given.
from TensorFlow.spectraltools import metric_based_pruning
pruned_model = metric_based_pruning(model,
eval_dictionary,
compile_dictionary,
compare_metric='accuracy',
max_delta_percent=10,
**kwargs)
model
: the trained model to be pruned.
eval_dictionary
: the dictionary with the arguments to be passed to the evaluate
method of the model.
compile_dictionary
: the dictionary with the arguments to be passed to the compile
method of the model.
max_delta_percent
: maximal variation of the given indicator at which break the pruning process.
compare_metric
: indicator to be used (the corresponding metric name should be used while compiling the model)
Contributing
Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.
License
spectraltools
was created by Lorenzo Giambagli. It is licensed under the terms of the MIT license.
Credits
spectraltools
was created with [cookiecutter
] (https://cookiecutter.readthedocs.io/en/latest/) and the py-pkgs-cookiecutter
template.
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file spectraltools-1.2.0.tar.gz
.
File metadata
- Download URL: spectraltools-1.2.0.tar.gz
- Upload date:
- Size: 13.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | ecade7261a8610ccf934adb277247cabcda68f6a9430f255b54e5fec2e779dcd |
|
MD5 | c37301557c65054f8075a8c5a68a03c8 |
|
BLAKE2b-256 | 97911a9c920c846b8f261566867616c8c516890a9c5ac808dbbf687fdbd35bd0 |
File details
Details for the file spectraltools-1.2.0-py3-none-any.whl
.
File metadata
- Download URL: spectraltools-1.2.0-py3-none-any.whl
- Upload date:
- Size: 11.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | aba547f6c077d45605cef68bad1f469571f2ae919d48ce8c24b73d789c98bdd0 |
|
MD5 | 168f13280a17c91471c3e7e7a67c29b1 |
|
BLAKE2b-256 | bda9a0e70e811106e1c83c1862f342b56e1f1be8f0e89cc65eb5dd980ad3dcdb |