
Simple hierarchal models in PyTorch

Hierarchal Classification Networks

While working on a classification task in which hierarchies were intrinsic to the classes, I searched for libraries that perform simple classification over grouped, hierarchical classes. I did not find any suited to this relatively simple task, so I set out to create a more general solution that others can hopefully benefit from.

The concept is quite simple: create a general architecture for groupings of classes that depend on each other. Starting from a basic model concept, I set out to build something in PyTorch that represents this idea.

Example

Let us take an image geolocation problem where we want to predict the city, county, and district of a location. We will call these groupings a, b, and c respectively. Given an input image, we want to predict all three classes, but we also need an architecture in which the relationships between them are properly represented. The network architecture below illustrates a possible solution, which this package attempts to implement with a degree of adaptability. For an input image, the architecture can be visualized as follows:

[Figure: Network Architecture]

where the class hierarchy is as follows:

[Figure: Class Hierarchy]

The class hierarchy has a similar structure to an example within this package. Each node is a tuple of a named grouping and the number of classes within that grouping. This is the reason for the sizes of the final outputs in the network architecture. The large green plus signs within circles indicate concatenation of the two inputs (green arrowed lines) leading into them. This is why the sections for classes b and c have input size 4096 + 1024 = 5120.
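The size arithmetic above can be sketched in plain Python. This is only an illustration, not the package's code: the helper `head_input_sizes` and the constants `BASE_FEATURES` and `PARENT_FEATURES` are hypothetical names, the sizes 4096 and 1024 are taken from the figure description, and the hierarchy is assumed to be the A, B, C chain used later in this README.

```python
# Hypothetical sizes taken from the figure description: the shared base
# produces 4096 features and a parent section passes 1024 features on.
BASE_FEATURES = 4096
PARENT_FEATURES = 1024

# (group name, number of classes) -> list of dependent groups.
hierarchy = {
    ("A", 2): [("B", 5)],
    ("B", 5): [("C", 7)],
}


def head_input_sizes(hierarchy, base_features, parent_features):
    """Return {group name: input width of that group's section}."""
    # Every group that appears as a value depends on some parent.
    children = {child for kids in hierarchy.values() for child in kids}
    sizes = {}
    for parent, kids in hierarchy.items():
        if parent not in children:
            # A root group sees only the base features.
            sizes[parent[0]] = base_features
        for child in kids:
            # A child sees base features concatenated with parent features.
            sizes[child[0]] = base_features + parent_features
    return sizes


sizes = head_input_sizes(hierarchy, BASE_FEATURES, PARENT_FEATURES)
print(sizes)  # {'A': 4096, 'B': 5120, 'C': 5120}
```

Under these assumptions only the root grouping receives the bare 4096 base features; every dependent grouping receives the concatenated 5120.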

Installation

First, install PyTorch. The requirements.txt file installs PyTorch 1.7, but the best way to install it is to follow the instructions on PyTorch's website. This package accepts any PyTorch version >= 1.0, but it has only been tested on PyTorch 1.7, so earlier versions may or may not work. Once PyTorch is installed, the package can be installed with pip:

pip install simple-hierarchy-pytorch

The repository can also be cloned and then built from source. This can be done like so:

git clone https://github.com/rajivsarvepalli/SimpleHierarchy.git
cd SimpleHierarchy
python setup.py sdist bdist_wheel
pip install dist/simple_hierarchy_pytorch-0.0.1-py3-none-any.whl

Finally, the repository can simply be downloaded and imported as Python code, since there are essentially only two required classes.

Getting Started

This architecture allows for simple yet adaptable hierarchal classification for basic tasks that involve finite hierarchies. The package targets image classification where each input belongs to several related groups, but it may serve other purposes equally well. Below is an example of how to use the package's HierarchalModel class:

import torch
import torch.nn as nn

from simple_hierarchy.hierarchal_model import HierarchalModel

# Each key is a (group name, number of classes) tuple that maps to the
# groups depending on it.
hierarchy = {
    ('A', 2): [('B', 5)],
    ('B', 5): [('C', 7)]
}
# Small convolutional base network. For a 3x50x50 input, the flattened
# feature size is 16 * 9 * 9 = 1296.
model_base = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features=1296, out_features=120),
    nn.ReLU(),
    nn.Linear(in_features=120, out_features=84),
    nn.ReLU()
)
model = HierarchalModel(hierarchy, (84, 32, 32), base_model=model_base)
# Example input: a batch containing one random 3x50x50 image
a = torch.rand(3, 50, 50).unsqueeze(0)
model(a)

Then the model can be trained on an image dataset like any other model.
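As a rough illustration of what such training might look like, the sketch below sums one cross-entropy loss per grouping. Everything in it is an assumption for illustration: `TinyTwoHead` is a hypothetical stand-in model, not part of this package, the data is random, and whether HierarchalModel returns its predictions as one tensor per grouping should be checked against the package itself.

```python
import torch
import torch.nn as nn


class TinyTwoHead(nn.Module):
    """Hypothetical stand-in that returns one logits tensor per grouping."""

    def __init__(self, in_features=8, classes=(2, 5)):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(in_features, c) for c in classes]
        )

    def forward(self, x):
        return [head(x) for head in self.heads]


torch.manual_seed(0)
model = TinyTwoHead()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Random stand-in batch: 4 samples, one integer label per grouping.
inputs = torch.rand(4, 8)
targets = [torch.randint(0, 2, (4,)), torch.randint(0, 5, (4,))]

losses = []
for _ in range(20):
    optimizer.zero_grad()
    outputs = model(inputs)
    # Sum the per-grouping losses so every level contributes to the gradient.
    loss = sum(criterion(out, tgt) for out, tgt in zip(outputs, targets))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Summing the losses with equal weight is the simplest choice; weighting the groupings differently is an equally reasonable design if one level of the hierarchy matters more than the others.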

Additionally, a Jupyter notebook within this repository illustrates some examples of how to use and run these classes. Most of the notebook is self-contained, so the necessary code is already inside it, independent of the package; an example that uses this package appears at the beginning. The formulation is quite simple, so it should not take much additional work to incorporate the HierarchalModel into your networks, and for specific cases it can easily be implemented directly. The HierarchalModel class just provides a general solution for more use cases and gave me a chance to test and build some architectural ideas.
