Spectral Clustering Using Deep Neural Networks
SpectralNet
SpectralNet is a Python package that performs spectral clustering with deep neural networks.
This package is based on the paper "SpectralNet: Spectral Clustering Using Deep Neural Networks" (Shaham et al., ICLR 2018).
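If you are new to spectral clustering, the following minimal NumPy sketch shows the classical procedure in the style of Ng, Jordan and Weiss. It builds a Gaussian affinity matrix, takes the leading eigenvectors of the normalized affinity (equivalently, the bottom eigenvectors of the normalized graph Laplacian), and runs k-means on the row-normalized embedding. SpectralNet replaces the explicit eigendecomposition with a neural network that learns to map points into this eigenspace, which lets it scale to large datasets and embed unseen points. The code below is not part of the spectralnet package and does not use its API. The function names, the fixed bandwidth `sigma` and the deterministic farthest-point k-means initialization are illustrative assumptions.

```python
import numpy as np

# Illustrative sketch of classical (non-neural) spectral clustering;
# not part of the spectralnet package. Names and defaults are hypothetical.


def _kmeans(U, k, n_iter=100):
    # Deterministic farthest-point initialization (an assumption for reproducibility)
    centers = [U[0]]
    for _ in range(1, k):
        dists = np.min([np.sum((U - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(U[np.argmax(dists)])
    centers = np.array(centers)
    labels = np.zeros(len(U), dtype=int)
    for _ in range(n_iter):
        # Assign each point to its nearest center
        sq = ((U[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = np.argmin(sq, axis=1)
        # Recompute centers, keeping the old one if a cluster becomes empty
        new_centers = np.array([
            U[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def spectral_clustering(X, n_clusters, sigma=1.0):
    X = np.asarray(X, dtype=float)
    # Gaussian affinities W_ij = exp(-||x_i - x_j||^2 / (2 sigma^2)), no self-loops
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    W = np.exp(-sq_dists / (2.0 * sigma ** 2))
    np.fill_diagonal(W, 0.0)
    # Normalized affinity D^{-1/2} W D^{-1/2}
    d_inv_sqrt = 1.0 / np.sqrt(W.sum(axis=1) + 1e-12)
    M = d_inv_sqrt[:, None] * W * d_inv_sqrt[None, :]
    # eigh returns eigenvalues in ascending order, so the last columns are the
    # leading eigenvectors of M (the bottom eigenvectors of the Laplacian I - M)
    _, eigvecs = np.linalg.eigh(M)
    U = eigvecs[:, -n_clusters:]
    # Row-normalize the spectral embedding, then cluster it with k-means
    U = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)
    return _kmeans(U, n_clusters)
```

Because every step needs the full n × n affinity matrix and its eigendecomposition, this approach becomes expensive for large datasets and has no direct way to assign new points, which are the limitations SpectralNet is designed to address.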
Installation
You can install the latest version of the package via:
pip install spectralnet
Usage
Clustering
The basic usage is straightforward:
from spectralnet import SpectralNet
spectralnet = SpectralNet(n_clusters=10)
spectralnet.fit(X) # X is the dataset and it should be a torch.Tensor
cluster_assignments = spectralnet.predict(X) # Get the final assignments to clusters
If your dataset has labels and you want to measure clustering accuracy (ACC) and normalized mutual information (NMI), you can do the following:
import numpy as np
from spectralnet import SpectralNet
from spectralnet import Metrics
spectralnet = SpectralNet(n_clusters=2)
spectralnet.fit(X, y)  # X is the dataset and it should be a torch.Tensor; y holds the labels
cluster_assignments = spectralnet.predict(X)  # Get the final assignments to clusters
y = y.detach().cpu().numpy()  # Convert the labels to NumPy in case they are a torch.Tensor
acc_score = Metrics.acc_score(cluster_assignments, y, n_clusters=2)
nmi_score = Metrics.nmi_score(cluster_assignments, y)
print(f"ACC: {np.round(acc_score, 3)}")
print(f"NMI: {np.round(nmi_score, 3)}")
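ACC is the fraction of points labeled correctly under the best one-to-one mapping between cluster ids and class labels. NMI is the mutual information between the assignments and the labels, normalized by their entropies. The sketch below computes both with NumPy to illustrate what the reported numbers mean. It is not the implementation behind `Metrics.acc_score` or `Metrics.nmi_score`, and the function names are hypothetical. ACC here brute-forces all label permutations, which is only practical for a handful of clusters (the Hungarian algorithm is the usual scalable choice). NMI uses the arithmetic mean of the two entropies as the normalizer, which is one common convention among several.

```python
import itertools

import numpy as np


def clustering_accuracy(pred, y, n_clusters):
    """Best-mapping accuracy; assumes ids and labels lie in range(n_clusters)."""
    pred = np.asarray(pred)
    y = np.asarray(y)
    best = 0
    # Brute force: try every one-to-one mapping cluster id -> class label
    for perm in itertools.permutations(range(n_clusters)):
        mapped = np.array(perm)[pred]  # perm[c] is the label assigned to cluster c
        best = max(best, int(np.sum(mapped == y)))
    return best / len(y)


def normalized_mutual_info(pred, y):
    """NMI with the arithmetic mean of the two entropies as the normalizer."""
    pred = np.asarray(pred)
    y = np.asarray(y)
    _, pred_idx = np.unique(pred, return_inverse=True)
    _, y_idx = np.unique(y, return_inverse=True)
    # Joint distribution from the contingency table of (cluster, label) counts
    C = np.zeros((pred_idx.max() + 1, y_idx.max() + 1))
    np.add.at(C, (pred_idx, y_idx), 1)
    P = C / len(y)
    pu, pv = P.sum(axis=1), P.sum(axis=0)
    nz = P > 0
    mi = np.sum(P[nz] * np.log(P[nz] / np.outer(pu, pv)[nz]))
    hu = -np.sum(pu[pu > 0] * np.log(pu[pu > 0]))
    hv = -np.sum(pv[pv > 0] * np.log(pv[pv > 0]))
    denom = (hu + hv) / 2
    # Two single-cluster partitions are treated as identical
    return mi / denom if denom > 0 else 1.0
```

Both scores are invariant to how clusters are numbered: assignments `[1, 1, 0, 0]` against labels `[0, 0, 1, 1]` give perfect scores on both metrics.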
See the code documentation for further details and functionality.
Running examples
To run the model on the two-moons or MNIST dataset, first cd into the examples folder and then run:
python3 cluster_twomoons.py
or
python3 cluster_mnist.py
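For a sense of what the two-moons dataset looks like, here is a hedged NumPy sketch that generates two interleaving half circles, following the same geometry as scikit-learn's `make_moons`. The function name `make_two_moons` and its parameters are illustrative assumptions; this is not necessarily how `cluster_twomoons.py` builds its data.

```python
import numpy as np


def make_two_moons(n_samples=200, noise=0.05, seed=0):
    """Hypothetical helper: two interleaving half circles with Gaussian noise."""
    rng = np.random.default_rng(seed)
    n_outer = n_samples // 2
    n_inner = n_samples - n_outer
    # Upper half circle of radius 1 centered at (0, 0)
    t_outer = np.linspace(0, np.pi, n_outer)
    outer = np.stack([np.cos(t_outer), np.sin(t_outer)], axis=1)
    # Lower half circle shifted to (1, 0.5) so the two moons interleave
    t_inner = np.linspace(0, np.pi, n_inner)
    inner = np.stack([1 - np.cos(t_inner), 0.5 - np.sin(t_inner)], axis=1)
    X = np.vstack([outer, inner]) + rng.normal(scale=noise, size=(n_samples, 2))
    # Ground-truth labels: 0 for the upper moon, 1 for the lower moon
    y = np.concatenate([np.zeros(n_outer, dtype=int), np.ones(n_inner, dtype=int)])
    return X.astype(np.float32), y
```

Since SpectralNet expects a `torch.Tensor`, data produced this way would need converting first, for example with `torch.from_numpy(X)`. The two moons are not linearly separable, which is why a spectral method is a natural fit for them.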
Citation
@inproceedings{shaham2018,
author = {Uri Shaham and Kelly Stanton and Henri Li and Boaz Nadler and Ronen Basri and Yuval Kluger},
title = {SpectralNet: Spectral Clustering Using Deep Neural Networks},
booktitle = {Proc. ICLR 2018},
year = {2018}
}