
A Hierarchical Softmax Framework for PyTorch.

Project description


Installation

hierarchicalsoftmax can be installed from PyPI:

pip install hierarchicalsoftmax

Alternatively, hierarchicalsoftmax can be installed with pip directly from the git repository:

pip install git+https://github.com/rbturnbull/hierarchicalsoftmax.git

Usage

Build up a hierarchy tree for your categories using SoftmaxNode instances:

from hierarchicalsoftmax import SoftmaxNode

root = SoftmaxNode("root")
a = SoftmaxNode("a", parent=root)
aa = SoftmaxNode("aa", parent=a)
ab = SoftmaxNode("ab", parent=a)
b = SoftmaxNode("b", parent=root)
ba = SoftmaxNode("ba", parent=b)
bb = SoftmaxNode("bb", parent=b)

The SoftmaxNode class inherits from the anytree Node class, which means that you can use methods from that library to build and interact with your hierarchy tree.
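
For example, anytree properties such as leaves and iterators such as PreOrderIter work directly on a SoftmaxNode tree. A minimal sketch using the tree built above:

from anytree import PreOrderIter

leaf_names = [node.name for node in root.leaves]            # ['aa', 'ab', 'ba', 'bb']
ordered_names = [node.name for node in PreOrderIter(root)]  # pre-order traversal of all nodes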

The tree can be rendered as a string with the render method:

root.render(print=True)

This results in a text representation of the tree:

root
├── a
│   ├── aa
│   └── ab
└── b
    ├── ba
    └── bb

The tree can also be rendered to a file using graphviz if it is installed:

root.render(filepath="tree.svg")

[An example tree rendering]

Then add a final layer to your network with the correct number of output features for the softmax layers. You can do this manually by setting the number of output features to root.layer_size (see the sketch after the following example), or you can use the HierarchicalSoftmaxLinear or HierarchicalSoftmaxLazyLinear classes:

from torch import nn
from hierarchicalsoftmax import HierarchicalSoftmaxLinear

model = nn.Sequential(
    nn.Linear(in_features=20, out_features=100),
    nn.ReLU(),
    HierarchicalSoftmaxLinear(in_features=100, root=root)
)
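
For comparison, the manual approach mentioned above sizes a plain final layer with root.layer_size. A minimal sketch (the layer widths are illustrative):

from torch import nn

manual_model = nn.Sequential(
    nn.Linear(in_features=20, out_features=100),
    nn.ReLU(),
    nn.Linear(in_features=100, out_features=root.layer_size),  # same output size as HierarchicalSoftmaxLinear
)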

Once you have the hierarchy tree, you can use the HierarchicalSoftmaxLoss module:

from hierarchicalsoftmax import HierarchicalSoftmaxLoss

loss = HierarchicalSoftmaxLoss(root=root)
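
Since the loss is a standard torch.nn module, it drops into an ordinary training step. A minimal sketch, where inputs and target_ids are illustrative placeholder tensors (the targets are assumed to be the indexes of the target nodes):

outputs = model(inputs)                 # shape: (batch_size, root.layer_size)
batch_loss = loss(outputs, target_ids)  # combines the per-node softmax losses over the tree
batch_loss.backward()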

Metric functions are provided to compute the accuracy and the F1 score:

from hierarchicalsoftmax import greedy_accuracy, greedy_f1_score

accuracy = greedy_accuracy(predictions, targets, root=root)
f1 = greedy_f1_score(predictions, targets, root=root)

The nodes predicted from the final layer of the model can be inferred using the greedy_predictions function, which returns a list of the predicted nodes:

from hierarchicalsoftmax import greedy_predictions

outputs = model(inputs)
inferred_nodes = greedy_predictions(outputs, root=root)
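
The inferred nodes are ordinary SoftmaxNode objects, so their names and ancestry can be read off directly. A small sketch:

predicted_names = [node.name for node in inferred_nodes]
# The full path from the root is available through anytree, e.g. inferred_nodes[0].path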

Relative contributions to the loss

The loss at each node can be weighted relative to the others by setting the alpha value of the parent node. By default, the alpha value of each node is 1.

In the following example, the first level of classification (under the root node) contributes twice as much to the loss as the classification under the a or b nodes:

from hierarchicalsoftmax import SoftmaxNode

root = SoftmaxNode("root", alpha=2.0)
a = SoftmaxNode("a", parent=root)
aa = SoftmaxNode("aa", parent=a)
ab = SoftmaxNode("ab", parent=a)
b = SoftmaxNode("b", parent=root)
ba = SoftmaxNode("ba", parent=b)
bb = SoftmaxNode("bb", parent=b)

Label Smoothing

You can add label smoothing to the loss by setting the label_smoothing parameter on any of the nodes.
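
A minimal sketch, passing label_smoothing when constructing a node (the value here is illustrative):

root = SoftmaxNode("root", label_smoothing=0.1)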

Focal Loss

You can use the focal loss instead of a basic cross-entropy loss for any of the nodes by setting the gamma parameter on those nodes.
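
A minimal sketch, passing gamma when constructing a node (gamma=2.0 is a common choice for the focal loss and is illustrative here):

root = SoftmaxNode("root", gamma=2.0)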

Credits

Download files

Download the file for your platform.

Source Distribution

hierarchicalsoftmax-1.0.2.tar.gz (15.0 kB)


Built Distribution

hierarchicalsoftmax-1.0.2-py3-none-any.whl (16.1 kB)


File details

Details for the file hierarchicalsoftmax-1.0.2.tar.gz.

File metadata

  • Download URL: hierarchicalsoftmax-1.0.2.tar.gz
  • Upload date:
  • Size: 15.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.11.4 Darwin/22.1.0

File hashes

Hashes for hierarchicalsoftmax-1.0.2.tar.gz:

  • SHA256: a894dffdc7ec366d6c4846badcb3e21bf9b09b9eed247f4f116b310231195135
  • MD5: 4899890eedce5ea978f5f84bfe92847e
  • BLAKE2b-256: 28f3c39d040184ac01a53976ca272b4cc601f3e1177017d2afbcd8e0c253f323


File details

Details for the file hierarchicalsoftmax-1.0.2-py3-none-any.whl.

File metadata

File hashes

Hashes for hierarchicalsoftmax-1.0.2-py3-none-any.whl:

  • SHA256: 63cbbd76bc80f9980ce90e76394101850d780cdd7ddcaef5636fd12f5d829da0
  • MD5: c9d0595da838b2b321a0ec9b6d476df1
  • BLAKE2b-256: 94585bfef43e4e00cafa01b0e483aa500d43d9f38b203b917b7e83b1d5b4a2ce

