A Hierarchical Softmax Framework for PyTorch.
Installation
hierarchicalsoftmax can be installed from PyPI:
pip install hierarchicalsoftmax
Alternatively, hierarchicalsoftmax can be installed using pip from the git repository:
pip install git+https://github.com/rbturnbull/hierarchicalsoftmax.git
Usage
Build a hierarchy tree for your categories using SoftmaxNode instances:
from hierarchicalsoftmax import SoftmaxNode
root = SoftmaxNode("root")
a = SoftmaxNode("a", parent=root)
aa = SoftmaxNode("aa", parent=a)
ab = SoftmaxNode("ab", parent=a)
b = SoftmaxNode("b", parent=root)
ba = SoftmaxNode("ba", parent=b)
bb = SoftmaxNode("bb", parent=b)
The SoftmaxNode class inherits from the anytree Node class which means that you can use methods from that library to build and interact with your hierarchy tree.
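The hierarchy is an ordinary parent-linked tree. As an illustration only, the stdlib sketch below mirrors the example tree with a hypothetical `Node` class (not SoftmaxNode and not anytree's class). It shows the two queries that matter most for hierarchical classification: the leaves (the final categories) and the path from the root down to a node.

```python
from __future__ import annotations

from typing import List, Optional, Tuple


class Node:
    """Hypothetical parent-linked tree node (illustration only, not SoftmaxNode)."""

    def __init__(self, name: str, parent: Optional[Node] = None) -> None:
        self.name = name
        self.parent = parent
        self.children: List[Node] = []
        # Passing a parent registers this node as one of its children.
        if parent is not None:
            parent.children.append(self)

    @property
    def path(self) -> Tuple[Node, ...]:
        """Nodes from the root down to (and including) this node."""
        nodes: List[Node] = []
        node: Optional[Node] = self
        while node is not None:
            nodes.append(node)
            node = node.parent
        return tuple(reversed(nodes))

    @property
    def leaves(self) -> Tuple[Node, ...]:
        """Leaf descendants (the final categories), left to right."""
        if not self.children:
            return (self,)
        return tuple(leaf for child in self.children for leaf in child.leaves)


# The same hierarchy as in the SoftmaxNode example.
root = Node("root")
a = Node("a", parent=root)
aa = Node("aa", parent=a)
ab = Node("ab", parent=a)
b = Node("b", parent=root)
ba = Node("ba", parent=b)
bb = Node("bb", parent=b)

print([leaf.name for leaf in root.leaves])  # ['aa', 'ab', 'ba', 'bb']
print([node.name for node in ba.path])      # ['root', 'b', 'ba']
```

With anytree itself, the corresponding attributes are `root.leaves` and `node.path`.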
The tree can be rendered as a string with the render method:
root.render(print=True)
This results in a text representation of the tree:
root
├── a
│   ├── aa
│   └── ab
└── b
    ├── ba
    └── bb
The tree can also be rendered to a file using graphviz if it is installed:
root.render(filepath="tree.svg")
Next, add a final layer to your network whose number of outputs matches the softmax layers of the tree. You can do this manually by setting the layer's number of output features to root.layer_size. Alternatively, use the HierarchicalSoftmaxLinear or HierarchicalSoftmaxLazyLinear classes:
from torch import nn
from hierarchicalsoftmax import HierarchicalSoftmaxLinear
model = nn.Sequential(
    nn.Linear(in_features=20, out_features=100),
    nn.ReLU(),
    HierarchicalSoftmaxLinear(in_features=100, root=root),
)
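One way to picture the width of that final layer: assume the output vector concatenates one block of logits per parent node, with one logit per child. Its width is then the total number of non-root nodes, since every non-root node is exactly one child of one parent. The stdlib sketch below lays out those blocks for the example tree. The dictionary encoding, the `softmax_blocks` helper and the pre-order ordering are illustrative assumptions, not the library's internals.

```python
from typing import Dict, List, Tuple

# Hypothetical encoding of the example tree: each parent maps to its ordered children.
TREE: Dict[str, List[str]] = {
    "root": ["a", "b"],
    "a": ["aa", "ab"],
    "b": ["ba", "bb"],
}


def softmax_blocks(
    tree: Dict[str, List[str]], root: str = "root"
) -> Tuple[Dict[str, Tuple[int, int]], int]:
    """Give every parent a contiguous [start, end) slice of the output vector.

    Parents are visited in pre-order. Leaves get no slice because they have
    no children to choose between. Returns the slices and the total width.
    """
    blocks: Dict[str, Tuple[int, int]] = {}
    index = 0
    stack = [root]
    while stack:
        node = stack.pop()
        children = tree.get(node, [])
        if not children:
            continue
        blocks[node] = (index, index + len(children))
        index += len(children)
        # Push in reverse so children are visited in their listed order.
        stack.extend(reversed(children))
    return blocks, index


blocks, layer_size = softmax_blocks(TREE)
print(blocks)      # {'root': (0, 2), 'a': (2, 4), 'b': (4, 6)}
print(layer_size)  # 6
```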
Once you have the hierarchy tree, you can use the HierarchicalSoftmaxLoss module:
from hierarchicalsoftmax import HierarchicalSoftmaxLoss
loss = HierarchicalSoftmaxLoss(root=root)
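To give an intuition for what a hierarchical softmax loss measures, the sketch below computes it for a single sample. At every ancestor on the path from the root to the target, it applies a cross-entropy to that parent's block of logits. The tree encoding, the block layout and the helper names are assumptions for illustration. All nodes are weighted equally here, and the library's batching and reduction may differ.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical encoding of the example tree and an assumed output layout.
TREE: Dict[str, List[str]] = {"root": ["a", "b"], "a": ["aa", "ab"], "b": ["ba", "bb"]}
BLOCKS: Dict[str, Tuple[int, int]] = {"root": (0, 2), "a": (2, 4), "b": (4, 6)}
PARENT: Dict[str, str] = {child: parent for parent, children in TREE.items() for child in children}


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one block of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def hierarchical_nll(logits: List[float], target: str) -> float:
    """Sum one cross-entropy term per ancestor on the path from the root to target."""
    total = 0.0
    node = target
    while node in PARENT:
        parent = PARENT[node]
        start, end = BLOCKS[parent]
        log_probs = log_softmax(logits[start:end])
        # The correct choice at this parent is the child leading towards the target.
        total -= log_probs[TREE[parent].index(node)]
        node = parent
    return total


# With all-zero logits every two-way softmax is uniform, so each level costs ln 2.
print(hierarchical_nll([0.0] * 6, "ba"))  # approx. 1.386 (= 2 ln 2)
```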
Metric functions are provided to compute the accuracy and the F1 score:
from hierarchicalsoftmax import greedy_accuracy, greedy_f1_score
accuracy = greedy_accuracy(predictions, targets, root=root)
f1 = greedy_f1_score(predictions, targets, root=root)
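For intuition about these two metrics, here is a stdlib sketch that works on predictions already decoded to node names rather than on tensors. The helper names are hypothetical. Macro averaging (the unweighted mean of per-class F1) is an assumption, and greedy_f1_score may average differently.

```python
from collections import Counter
from typing import Sequence


def accuracy(predicted: Sequence[str], targets: Sequence[str]) -> float:
    """Fraction of samples whose predicted node equals the target node."""
    if len(predicted) != len(targets) or not targets:
        raise ValueError("need two non-empty sequences of equal length")
    return sum(p == t for p, t in zip(predicted, targets)) / len(targets)


def macro_f1(predicted: Sequence[str], targets: Sequence[str]) -> float:
    """Unweighted mean of per-class F1 scores over every class that occurs."""
    if len(predicted) != len(targets) or not targets:
        raise ValueError("need two non-empty sequences of equal length")
    tp: Counter = Counter()
    fp: Counter = Counter()
    fn: Counter = Counter()
    for p, t in zip(predicted, targets):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # the predicted class gains a false positive
            fn[t] += 1  # the true class gains a false negative
    labels = set(predicted) | set(targets)
    # Per-class F1 = 2TP / (2TP + FP + FN); never 0/0 for a class that occurs.
    return sum(2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) for c in labels) / len(labels)


predicted = ["aa", "ab", "ba", "ba"]
targets = ["aa", "ab", "ba", "bb"]
print(accuracy(predicted, targets))  # 0.75
print(macro_f1(predicted, targets))  # approx. 0.667 = (1 + 1 + 2/3 + 0) / 4
```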
The nodes predicted from the final layer of the model can be inferred using the greedy_predictions function, which returns a list of the predicted nodes:
from hierarchicalsoftmax import greedy_predictions
outputs = model(inputs)
inferred_nodes = greedy_predictions(outputs, root=root)
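"Greedy" decoding can be understood as follows. Start at the root, choose the highest-scoring child within that parent's block of logits, and repeat from the chosen child until a leaf is reached. The sketch below applies this rule to one sample's logits given as a plain list. The tree encoding, the block layout and `greedy_path` are hypothetical and are not the library's implementation.

```python
from typing import Dict, List, Tuple

# Hypothetical encoding of the example tree and an assumed output layout.
TREE: Dict[str, List[str]] = {"root": ["a", "b"], "a": ["aa", "ab"], "b": ["ba", "bb"]}
BLOCKS: Dict[str, Tuple[int, int]] = {"root": (0, 2), "a": (2, 4), "b": (4, 6)}


def greedy_path(logits: List[float], root: str = "root") -> List[str]:
    """Descend from the root, taking the highest-scoring child at each parent."""
    path: List[str] = []
    node = root
    while node in TREE:  # stop once a leaf (a node with no children) is reached
        start, end = BLOCKS[node]
        block = logits[start:end]
        # Softmax is monotonic, so the arg-max logit is also the arg-max probability.
        best = max(range(len(block)), key=block.__getitem__)
        node = TREE[node][best]
        path.append(node)
    return path


print(greedy_path([0.1, 2.0, 0.0, 0.0, -1.0, 3.0]))  # ['b', 'bb']
```

Greedy decoding needs only one arg-max per level. However, it can miss the leaf with the highest product of probabilities along its path, because an early choice is never revisited.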
Relative contributions to the loss
The losses for the nodes can be weighted relative to one another by setting the alpha value of each parent node. By default, the alpha value of a node is 1.
In the example below, the loss for the first level of classification (under the root node) contributes twice as much to the total loss as the losses under the a or b nodes.
from hierarchicalsoftmax import SoftmaxNode
root = SoftmaxNode("root", alpha=2.0)
a = SoftmaxNode("a", parent=root)
aa = SoftmaxNode("aa", parent=a)
ab = SoftmaxNode("ab", parent=a)
b = SoftmaxNode("b", parent=root)
ba = SoftmaxNode("ba", parent=b)
bb = SoftmaxNode("bb", parent=b)
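As a worked example, assume that each parent's alpha scales the cross-entropy of its own softmax, and that the terms are summed along the target's path for a single sample. This is a simplification, and the library's reduction over a batch may differ. Under these assumptions, the loss for a sample whose target is ba in the tree above would be:

```latex
\mathcal{L}(\mathrm{ba})
  = \alpha_{\mathrm{root}}\,\mathrm{CE}_{\mathrm{root}}(\mathrm{b})
  + \alpha_{\mathrm{b}}\,\mathrm{CE}_{\mathrm{b}}(\mathrm{ba})
  = 2\,\mathrm{CE}_{\mathrm{root}}(\mathrm{b}) + \mathrm{CE}_{\mathrm{b}}(\mathrm{ba})
```

If both softmaxes were uniform (each cross-entropy equal to ln 2), this would give 3 ln 2 ≈ 2.08, compared with 2 ln 2 ≈ 1.39 under the default alpha of 1.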
Label Smoothing
You can add label smoothing to the loss by setting the label_smoothing parameter on any of the nodes.
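Label smoothing replaces the one-hot target at a node with a mixture: (1 - ε) on the correct child plus ε spread uniformly over all K children. The sketch below computes one node's smoothed cross-entropy. It assumes that label_smoothing follows the same convention as PyTorch's `CrossEntropyLoss(label_smoothing=...)`, applied within a single node's block of children. The helper name is hypothetical.

```python
import math
from typing import List


def label_smoothed_ce(logits: List[float], target_index: int, smoothing: float = 0.0) -> float:
    """Cross-entropy of one node's softmax against a smoothed target distribution."""
    k = len(logits)
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    # Mix the one-hot target with a uniform distribution over the K children.
    weights = [smoothing / k] * k
    weights[target_index] += 1.0 - smoothing
    return -sum(w * lp for w, lp in zip(weights, log_probs))


print(label_smoothed_ce([2.0, 0.0], 0))                 # plain cross-entropy
print(label_smoothed_ce([2.0, 0.0], 0, smoothing=0.2))  # larger: confidence is penalised
```

With equal logits the result is ln K for any smoothing, because the prediction is already uniform.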
Focal Loss
You can use focal loss instead of the standard cross-entropy loss for any node by setting that node's gamma parameter.
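The usual focal loss replaces the cross-entropy -log p_t with -(1 - p_t)^γ log p_t, where p_t is the predicted probability of the correct child. Setting γ = 0 recovers plain cross-entropy. The sketch below computes one node's term. It assumes that a node's gamma parameter plays the role of this exponent (inferred from the parameter name only), and the helper name is hypothetical.

```python
import math
from typing import List


def focal_loss(logits: List[float], target_index: int, gamma: float = 0.0) -> float:
    """Focal loss for one node's softmax; gamma=0 gives plain cross-entropy."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_p_t = logits[target_index] - log_z
    p_t = math.exp(log_p_t)
    # (1 - p_t) ** gamma down-weights examples the model already classifies confidently.
    return -((1.0 - p_t) ** gamma) * log_p_t


print(focal_loss([2.0, 0.0], 0))             # plain cross-entropy, approx. 0.127
print(focal_loss([2.0, 0.0], 0, gamma=2.0))  # much smaller for this easy example
```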
Credits
Robert Turnbull <robert.turnbull@unimelb.edu.au>