Skip to main content

PyTorch implementation of DeepType with clustering and sparsity

Project description

torch-deeptype

PyTorch implementation of DeepType.

Installation

Run pip install torch-deeptype

Usage

Usage After installing (pip install torch-deeptype), follow these steps:

1. Define your model

Create a DeeptypeModel subclass that implements:

forward(self, x: Tensor) -> Tensor get_input_layer_weights(self) -> Tensor get_hidden_representations(self, x: Tensor) -> Tensor

Tip: Have forward() call get_hidden_representations() to avoid duplicating the hidden-layer code.

import torch
import torch.nn as nn
from torch_deeptype import DeeptypeModel

class MyNet(DeeptypeModel):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.input_layer   = nn.Linear(input_dim, hidden_dim)
        self.h1            = nn.Linear(hidden_dim, hidden_dim)
        self.cluster_layer = nn.Linear(hidden_dim, hidden_dim // 2)
        self.output_layer  = nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Notice how forward() gets the hidden representations
        hidden = self.get_hidden_representations(x)
        return self.output_layer(hidden)

    def get_input_layer_weights(self) -> torch.Tensor:
        return self.input_layer.weight

    def get_hidden_representations(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.input_layer(x))
        x = torch.relu(self.h1(x))
        x = torch.relu(self.cluster_layer(x))
        return x

2. Prepare your data

Wrap your tensors in a TensorDataset and DataLoader as usual:

from torch.utils.data import TensorDataset, DataLoader

# Example with random data:
X = torch.randn(1000, 20)         # 1000 samples, 20 features
y = torch.randint(0, 5, (1000,))  # 5 classes

dataset      = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

3. Instantiate the trainer

Use DeeptypeTrainer to set up both phases of DeepType training:

from torch_deeptype import DeeptypeTrainer

trainer = DeeptypeTrainer(
    model           = MyNet(input_dim=20, hidden_dim=64, output_dim=5),
    train_loader    = train_loader,
    primary_loss_fn = nn.CrossEntropyLoss(),
    num_clusters    = 8,       # K in KMeans
    sparsity_weight = 0.01,    # α for L₂ sparsity on input weights
    cluster_weight  = 0.5,     # β for cluster‐rep loss
    verbose         = True     # print per-epoch loss summaries
)

4. Run training

Call trainer.train(...) to execute the Deeptype training

trainer.train(
    main_epochs           = 15,     # epochs for joint phase
    main_lr               = 1e-4,   # LR for joint phase
    pretrain_epochs       = 10,     # epochs for pretrain phase
    pretrain_lr           = 1e-3,   # LR for pretrain (defaults to main_lr if None)
    train_steps_per_batch = 8       # inner updates per batch in joint phase
)

With verbose=True, you’ll see three loss components logged each epoch:

  • Primary (classification/regression loss)
  • Sparsity (input-weight L₂ penalty)
  • Cluster (hidden-representation vs. KMeans centers)

5. Extract clusters and important inputs

After training, you can inspect:

  • KMeans clusters over your dataset’s hidden representations
  • Input‐feature importances via the L₂‐norm of each input weight column
from torch.utils.data import TensorDataset

# 1) Prepare the same dataset you trained on
dataset = TensorDataset(X, y)

# 2) Compute clusters
#    Returns:
#      - `centroids`: Tensor[num_clusters, hidden_dim]
#      - `labels`:    np.ndarray[N] of cluster assignments
centroids, labels = trainer.get_clusters(dataset)

print("Centroids shape:", centroids.shape)
print("Cluster assignments for first 10 samples:", labels[:10])


# 3) Compute input‐feature importance (on your model)
#    importance[i] = || W[:, i] ||₂ for first‐layer weights W
importances = trainer.model.get_input_importance()
print("Importances:", importances)

# 4) Get features sorted by importance
#    returns a Tensor of feature indices, most important first
sorted_idx = trainer.model.get_sorted_input_indices()
print("Top 5 features by importance:", sorted_idx[:5].tolist())

That’s all you need to get DeepType running end-to-end!

If you're a more advanced user, you can also use the SparsityLoss and ClusterRepresentationLoss directly.

Acknowledgements

This implementation is based on Runpu Chen's original implementation here. The original paper that introduced DeepType can be found here.

Check my article on the paper here

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch_deeptype-0.1.1.tar.gz (12.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torch_deeptype-0.1.1-py3-none-any.whl (12.6 kB view details)

Uploaded Python 3

File details

Details for the file torch_deeptype-0.1.1.tar.gz.

File metadata

  • Download URL: torch_deeptype-0.1.1.tar.gz
  • Upload date:
  • Size: 12.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.2

File hashes

Hashes for torch_deeptype-0.1.1.tar.gz
Algorithm Hash digest
SHA256 3b8989ca162fd1bbbed18e87e7c36913a0193fb38009c918358e204d075241c2
MD5 313fd85555d0af2790879389b6033385
BLAKE2b-256 e17367fb0bda38b6e04d8bf17b8fe1826196649f4b8006e8c1f257438cbc1bad

See more details on using hashes here.

File details

Details for the file torch_deeptype-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: torch_deeptype-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 12.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.2

File hashes

Hashes for torch_deeptype-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 02ddd24fd26ae5d551c8c2d2237ce184e6b9cee9f8c468ecee6f83695f931e37
MD5 5fba209d914d41df5030bcde3a39e79e
BLAKE2b-256 55d4b253e1199432cc884717ede912c8dee9cfe36200ff8a757e4595571a542e

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page