PyTorch implementation of DeepType with clustering and sparsity

torch-deeptype

PyTorch implementation of DeepType.

Installation

Run pip install torch-deeptype

Usage

After installing (pip install torch-deeptype), follow these steps:

  1. Define your model

Create a DeeptypeModel subclass that implements:

  • forward(self, x: Tensor) -> Tensor
  • get_input_layer_weights(self) -> Tensor
  • get_hidden_representations(self, x: Tensor) -> Tensor

Tip: Have forward() call get_hidden_representations() to avoid duplicating the hidden-layer code.

import torch
import torch.nn as nn
from torch_deeptype import DeeptypeModel

class MyNet(DeeptypeModel):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.input_layer   = nn.Linear(input_dim, hidden_dim)
        self.h1            = nn.Linear(hidden_dim, hidden_dim)
        self.cluster_layer = nn.Linear(hidden_dim, hidden_dim // 2)
        self.output_layer  = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Notice how forward() gets the hidden representations
        hidden = self.get_hidden_representations(x)
        return self.output_layer(hidden)

    def get_input_layer_weights(self) -> torch.Tensor:
        return self.input_layer.weight

    def get_hidden_representations(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.input_layer(x))
        x = torch.relu(self.h1(x))
        x = torch.relu(self.cluster_layer(x))
        return x
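
A quick shape check can catch wiring mistakes before training. The sketch below is an illustration only: StandInNet is a hypothetical plain nn.Module that mirrors the three methods of MyNet, so it runs without torch-deeptype installed. The dimensions match the values used later in this guide.

```python
import torch
import torch.nn as nn

class StandInNet(nn.Module):
    """Hypothetical stand-in exposing the same three methods as MyNet."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.cluster_layer = nn.Linear(hidden_dim, hidden_dim // 2)
        self.output_layer = nn.Linear(hidden_dim // 2, output_dim)

    def get_hidden_representations(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.input_layer(x))
        return torch.relu(self.cluster_layer(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reuse the hidden-layer code instead of duplicating it
        return self.output_layer(self.get_hidden_representations(x))

    def get_input_layer_weights(self) -> torch.Tensor:
        return self.input_layer.weight

net = StandInNet(input_dim=20, hidden_dim=64, output_dim=5)
batch = torch.randn(4, 20)  # 4 samples, 20 features

hidden_shape = tuple(net.get_hidden_representations(batch).shape)
logits_shape = tuple(net(batch).shape)
weights_shape = tuple(net.get_input_layer_weights().shape)

print(hidden_shape)   # (4, 32): batch x (hidden_dim // 2)
print(logits_shape)   # (4, 5):  batch x output_dim
print(weights_shape)  # (64, 20): nn.Linear stores weights as [out, in]
```

Note that the input-layer weight has one column per input feature, which is what makes per-feature importance scores possible.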
  2. Prepare your data

Wrap your tensors in a TensorDataset and DataLoader as usual:
from torch.utils.data import TensorDataset, DataLoader

# Example with random data:
X = torch.randn(1000, 20)         # 1000 samples, 20 features
y = torch.randint(0, 5, (1000,))  # 5 classes

dataset      = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
  3. Instantiate the trainer

Use DeeptypeTrainer to set up both phases of DeepType training:
from torch_deeptype import DeeptypeTrainer

trainer = DeeptypeTrainer(
    model           = MyNet(input_dim=20, hidden_dim=64, output_dim=5),
    train_loader    = train_loader,
    primary_loss_fn = nn.CrossEntropyLoss(),
    num_clusters    = 8,       # K in KMeans
    sparsity_weight = 0.01,    # α for L₂ sparsity on input weights
    cluster_weight  = 0.5,     # β for cluster‐rep loss
    verbose         = True     # print per-epoch loss summaries
)
  4. Run training

Call trainer.train(...) to run the pretraining phase followed by the joint phase:
trainer.train(
    main_epochs           = 15,     # epochs for joint phase
    main_lr               = 1e-4,   # LR for joint phase
    pretrain_epochs       = 10,     # epochs for pretrain phase
    pretrain_lr           = 1e-3,   # LR for pretrain (defaults to main_lr if None)
    train_steps_per_batch = 8       # inner updates per batch in joint phase
)

With verbose=True, you’ll see three loss components logged each epoch:

  • Primary (classification/regression loss)
  • Sparsity (input-weight L₂ penalty)
  • Cluster (hidden-representation vs. KMeans centers)
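
To make the three components concrete, here is a minimal sketch of how such terms could be computed and combined with the α and β weights. This README does not spell out the exact formulas the library uses, so the functions sparsity_penalty and cluster_penalty below are hypothetical, and both the column-norm form of the sparsity term and the squared-distance form of the cluster term are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def sparsity_penalty(input_weights: torch.Tensor) -> torch.Tensor:
    # Assumed form: one L2 norm per input feature (weight column), summed.
    # input_weights has shape [hidden_dim, input_dim], as in nn.Linear.weight.
    return torch.linalg.norm(input_weights, dim=0).sum()

def cluster_penalty(hidden, centers, assignments) -> torch.Tensor:
    # Assumed form: mean squared distance from each hidden vector
    # to the center of the cluster it is assigned to.
    return ((hidden - centers[assignments]) ** 2).sum(dim=1).mean()

# Toy tensors standing in for one batch
logits = torch.randn(6, 5)             # model outputs, 5 classes
targets = torch.randint(0, 5, (6,))    # class labels
W = torch.randn(64, 20)                # input-layer weights
hidden = torch.randn(6, 32)            # hidden representations
centers = torch.randn(8, 32)           # 8 cluster centers
assignments = torch.cdist(hidden, centers).argmin(dim=1)  # nearest center

alpha, beta = 0.01, 0.5                # sparsity_weight, cluster_weight
primary = F.cross_entropy(logits, targets)
sparsity = sparsity_penalty(W)
cluster = cluster_penalty(hidden, centers, assignments)
total = primary + alpha * sparsity + beta * cluster

print(f"primary={primary.item():.4f} "
      f"sparsity={sparsity.item():.4f} "
      f"cluster={cluster.item():.4f} "
      f"total={total.item():.4f}")
```

Under these assumptions, raising alpha pushes whole weight columns toward zero (dropping input features), while raising beta pulls hidden representations toward their cluster centers.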
  5. Extract clusters and important inputs

After training, you can inspect:

  • KMeans clusters over your dataset’s hidden representations
  • Input‐feature importances via the L₂‐norm of each input weight column
from torch.utils.data import TensorDataset

# 1) Prepare the same dataset you trained on
dataset = TensorDataset(X, y)

# 2) Compute clusters
#    Returns:
#      - `centroids`: Tensor[num_clusters, rep_dim], where rep_dim is the width
#                     of get_hidden_representations() (hidden_dim // 2 = 32 for MyNet)
#      - `labels`:    np.ndarray[N] of cluster assignments
centroids, labels = trainer.get_clusters(dataset)

print("Centroids shape:", centroids.shape)
print("Cluster assignments for first 10 samples:", labels[:10])


# 3) Compute input‐feature importance (on your model)
#    importance[i] = || W[:, i] ||₂ for first‐layer weights W
importances = trainer.model.get_input_importance()
print("Importances:", importances)

# 4) Get features sorted by importance
#    returns a Tensor of feature indices, most important first
sorted_idx = trainer.model.get_sorted_input_indices()
print("Top 5 features by importance:", sorted_idx[:5].tolist())
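
For intuition, the importance formula in the comment above can be reproduced by hand from any nn.Linear weight matrix. The sketch below is not the library's code; it applies the stated definition, importance[i] = || W[:, i] ||₂, to a small hand-filled layer (the layer and its values are made up for illustration).

```python
import torch
import torch.nn as nn

# A tiny first layer: 4 input features, 3 hidden units
layer = nn.Linear(4, 3)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([
        [3.0, 0.0, 1.0, 0.0],
        [4.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 2.0],
    ]))

W = layer.weight.detach()                 # shape [3, 4] = [out, in]
importance = torch.linalg.norm(W, dim=0)  # L2 norm of each column
order = torch.argsort(importance, descending=True)

print(importance.tolist())  # column norms: 5, 0, 1, 2
print(order.tolist())       # [0, 3, 2, 1], most important feature first
```

Feature 1 has an all-zero column, so it contributes nothing to the first layer and ranks last; this is the effect the sparsity penalty encourages for uninformative inputs.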

That’s all you need to get DeepType running end-to-end!

If you're a more advanced user, you can also use the SparsityLoss and ClusterRepresentationLoss directly.

Acknowledgements

This implementation is based on Runpu Chen's original implementation of DeepType, the method introduced in the original DeepType paper.
