Python package for graph-based clustering and semi-supervised learning

Graph-based Clustering and Semi-Supervised Learning

This Python package provides efficient implementations of modern graph-based learning algorithms for both semi-supervised learning and clustering. It includes several popular datasets (currently MNIST, FashionMNIST, CIFAR-10, and WEBKB), packaged so that users can easily test new algorithms and quickly compare them against existing methods.

This package reproduces the experiments from our paper:

Calder, Cook, Thorpe, Slepcev. Poisson Learning: Graph Based Semi-Supervised Learning at Very Low Label Rates. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:1306-1316, 2020.

Installation

Install with

pip install graphlearning

Wheels are provided for Windows, Mac, and Linux. Required packages include numpy, scipy, sklearn, matplotlib, and torch. The packages annoy and kymatio are needed only for nearest neighbor searches and the scattering transform, respectively; the rest of the code runs without them.
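
Since annoy and kymatio are optional, it can be handy to check up front which features will be unavailable. The snippet below is a minimal sketch using only the Python standard library; the helper name find_missing_optional is hypothetical and is not part of graphlearning.

```python
import importlib.util

# Optional packages named above, mapped to the feature each one enables.
OPTIONAL_PACKAGES = {
    "annoy": "nearest neighbor searches",
    "kymatio": "scattering transform",
}

def find_missing_optional(packages=OPTIONAL_PACKAGES):
    """Return {package: feature} for the packages that cannot be imported.

    Hypothetical helper for illustration; not part of graphlearning.
    """
    return {
        name: feature
        for name, feature in packages.items()
        if importlib.util.find_spec(name) is None
    }

if __name__ == "__main__":
    for name, feature in find_missing_optional().items():
        print(f"Optional package '{name}' not found; {feature} will be unavailable.")
```

The check only looks up each package without importing it, so it is cheap to run before starting a long experiment.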

To install from the GitHub source, which is updated more frequently, run

git clone https://github.com/jwcalder/GraphLearning
cd GraphLearning
pip install -r requirements.txt
python setup.py install --user

Getting started with basic experiments

Below we outline some basic ways the package can be used. The examples page from our GitHub repository contains several detailed example scripts that are useful for getting started.

A basic experiment comparing Laplace learning/Label propagation to Poisson learning on MNIST can be run with

import graphlearning as gl
gl.ssl_trials(dataset='mnist',metric='vae',algorithm='laplace',k=10,t=10)
gl.ssl_trials(dataset='mnist',metric='vae',algorithm='poisson',k=10,t=10)

Supported datasets include MNIST, FashionMNIST, WEBKB, and cifar. The metric determines how the graph is constructed. It can be 'raw' (Euclidean distance between raw data, available for all datasets), 'vae' (variational autoencoder weights, as described in our paper, for MNIST and FashionMNIST), 'scatter' (the scattering transform), or 'aet' (AutoEncoding Transformations weights, also described in our paper, for cifar). The argument k=10 specifies how many nearest neighbors to use in constructing the graph, and t=10 specifies how many trials to run, each with a random assignment of training/testing data. There are many other optional arguments, and full documentation is coming soon.
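
To make the role of k concrete, here is a small dependency-free sketch of building a symmetric k-nearest-neighbor graph from raw points using Euclidean distance. It is only an illustration: the function name knn_graph, the brute-force search, and the Gaussian weight formula are assumptions chosen for this example, and they are not claimed to match what graphlearning does internally.

```python
import math

def knn_graph(points, k):
    """Return a symmetric kNN graph as {i: {j: weight}}.

    Illustrative sketch only (brute force, O(n^2) distance computations).
    The weight exp(-(d / d_k)^2), with d_k the distance from point i to its
    k-th nearest neighbor, is an assumed choice for this example.
    """
    n = len(points)
    if not 1 <= k < n:
        raise ValueError("k must satisfy 1 <= k < number of points")
    weights = {i: {} for i in range(n)}
    for i in range(n):
        # Sort all other points by Euclidean distance to point i.
        dists = sorted((math.dist(points[i], points[j]), j) for j in range(n) if j != i)
        neighbors = dists[:k]
        scale = neighbors[-1][0] or 1.0  # guard against duplicate points
        for d, j in neighbors:
            w = math.exp(-((d / scale) ** 2))
            # Symmetrize: keep an edge if either endpoint selects the other.
            w = max(w, weights[i].get(j, 0.0))
            weights[i][j] = w
            weights[j][i] = w
    return weights

if __name__ == "__main__":
    pts = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (10.0, 0.0)]
    graph = knn_graph(pts, k=1)
    for i, nbrs in graph.items():
        print(i, sorted(nbrs))
```

Larger values of k give a denser, better-connected graph at the cost of more edges to store and process.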

Below is a list of currently supported algorithms with links to the corresponding papers.

Semi-supervised learning: Laplace, RandomWalk, Poisson, PoissonMBO, pLaplace, WNLL, ProperlyWeighted, NearestNeighbor, MBO, VolumeMBO, DynamicLabelPropagation, SparseLabelPropagation, CenteredKernel

Clustering: INCRES, Spectral, SpectralShiMalik, SpectralNgJordanWeiss

The algorithm names are case-insensitive in all scripts. NearestNeighbor chooses the label of the closest labeled node in the geodesic graph distance.
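
The NearestNeighbor rule described above can be sketched with a multi-source Dijkstra search: every labeled node starts at distance zero, and each node takes the label of whichever labeled node reaches it first. This is a minimal standard-library illustration, not the package's implementation; the function name and the assumption that the graph is given as edge lengths (rather than similarity weights) are choices made for this example.

```python
import heapq

def geodesic_nearest_neighbor(lengths, train_labels):
    """Label each node by its closest labeled node in graph distance.

    lengths:      {node: {neighbor: edge_length}} with non-negative lengths
                  (assumed input format for this sketch).
    train_labels: {node: label} for the labeled nodes.
    Returns {node: label} for every node reachable from a labeled node.
    """
    labels = {}
    # Seed the priority queue with all labeled nodes at distance 0.
    heap = [(0.0, node, label) for node, label in train_labels.items()]
    heapq.heapify(heap)
    while heap:
        dist, node, label = heapq.heappop(heap)
        if node in labels:
            continue  # already reached by a closer labeled node
        labels[node] = label
        for nbr, length in lengths[node].items():
            if nbr not in labels:
                heapq.heappush(heap, (dist + length, nbr, label))
    return labels

if __name__ == "__main__":
    # Path graph 0 - 1 - 2 - 3 - 4 with one long edge between nodes 2 and 3.
    lengths = {
        0: {1: 1.0},
        1: {0: 1.0, 2: 1.0},
        2: {1: 1.0, 3: 3.0},
        3: {2: 3.0, 4: 1.0},
        4: {3: 1.0},
    }
    print(geodesic_nearest_neighbor(lengths, {0: 0, 4: 1}))
```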

The accuracy scores are saved in the subdirectory Results/ using a separate .csv file for each experiment. These can be loaded to generate plots and tables (see the example scripts). The directory ResultsFromPaper/ contains all results from our ICML paper.

The commands shown above are rather high level, and can be split into several important subroutines when needed. The code below shows how to generate a weight matrix on the MNIST dataset, choose training data randomly, run Laplace and Poisson learning, and compute accuracy scores.

import graphlearning as gl

#Load labels and kNN data, and build a 10-nearest-neighbor weight matrix
labels = gl.load_labels('mnist')
I,J,D = gl.load_kNN_data('mnist',metric='vae')
W = gl.weight_matrix(I,J,D,10)

#Randomly choose training datapoints
num_train_per_class = 1
train_ind = gl.randomize_labels(labels, num_train_per_class)
train_labels = labels[train_ind]

#Run Laplace and Poisson learning
labels_laplace = gl.graph_ssl(W,train_ind,train_labels,algorithm='laplace')
labels_poisson = gl.graph_ssl(W,train_ind,train_labels,algorithm='poisson')

#Compute and print accuracy
print('Laplace learning: %.2f%%'%gl.accuracy(labels,labels_laplace,num_train_per_class))
print('Poisson learning: %.2f%%'%gl.accuracy(labels,labels_poisson,num_train_per_class))
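
For intuition about what the Laplace learning/label propagation step computes, the sketch below runs label propagation on a tiny graph in pure Python: labeled nodes keep their one-hot labels, and every other node is repeatedly replaced by the weighted average of its neighbors. On a connected graph this iteration converges to the harmonic (Laplace learning) solution. It is an illustrative toy, not the code behind gl.graph_ssl; the function name, the dict-of-dicts graph format, and the fixed iteration count are assumptions made for this example.

```python
def label_propagation(weights, train_labels, num_classes, num_iters=1000):
    """Toy label propagation by repeated neighbor averaging.

    weights:      {i: {j: w_ij}} symmetric, non-negative edge weights.
    train_labels: {node: class_index} for the labeled nodes.
    Returns {node: predicted_class_index} for every node.
    """
    # One-hot scores at labeled nodes, zeros everywhere else.
    scores = {i: [0.0] * num_classes for i in weights}
    for i, c in train_labels.items():
        scores[i][c] = 1.0

    for _ in range(num_iters):
        new_scores = {}
        for i, nbrs in weights.items():
            if i in train_labels or not nbrs:
                new_scores[i] = scores[i]  # labeled (or isolated) nodes stay fixed
                continue
            degree = sum(nbrs.values())
            # Weighted average of the neighbors' current scores.
            new_scores[i] = [
                sum(w * scores[j][c] for j, w in nbrs.items()) / degree
                for c in range(num_classes)
            ]
        scores = new_scores

    # Predict the class with the largest score at each node.
    return {i: max(range(num_classes), key=lambda c: s[c]) for i, s in scores.items()}

if __name__ == "__main__":
    # Path graph 0 - 1 - 2 - 3 - 4 - 5 with unit weights, labeled at both ends.
    path = {i: {} for i in range(6)}
    for i in range(5):
        path[i][i + 1] = 1.0
        path[i + 1][i] = 1.0
    print(label_propagation(path, {0: 0, 5: 1}, num_classes=2))
```

On large graphs one would solve the corresponding sparse linear system instead of iterating in pure Python; the loop here is kept only because it makes the averaging step explicit.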

Contact and questions

Email jwcalder@umn.edu with any questions or comments.

Acknowledgments

Several people have contributed to the development of this software:

  1. Mauricio Rios Flores (Machine Learning Researcher, Amazon)
  2. Brendan Cook (PhD Candidate in Mathematics, University of Minnesota)
  3. Matt Jacobs (Postdoc, UCLA)
