M-PHATE


[Figure: demonstration M-PHATE plot]

Multislice PHATE (M-PHATE) is a dimensionality reduction algorithm for the visualization of time-evolving data. To learn more about M-PHATE, you can read our preprint on arXiv, in which we apply it to the evolution of neural networks over the course of training. Above we show a demonstration of M-PHATE applied to a 3-layer MLP over 300 epochs of training, colored by epoch (left), hidden layer (center) and the digit label that most strongly activates each hidden unit (right). Below, we show the same network trained with dropout and embedded in 3D, also colored by the digit label that most strongly activates each hidden unit.

[Figure: 3D rotating GIF]


How it works

Multislice PHATE (M-PHATE) combines a novel multislice kernel construction with the PHATE visualization. Our kernel captures the dynamics of an evolving graph structure which, when visualized, gives unique intuition into the evolution of a system; in our preprint on arXiv, we show this applied to a neural network over the course of training and re-training. We compare M-PHATE to other dimensionality reduction techniques, showing that the combination of the multislice kernel and PHATE provides significant improvements to visualization. In two vignettes, we demonstrate the use of M-PHATE on established training tasks and learning methods in continual learning, and on regularization techniques commonly used to improve generalization performance.
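To make the PHATE half of this combination concrete, here is a heavily simplified numpy sketch of PHATE-style steps applied to an arbitrary affinity matrix: row-normalize it into a diffusion operator, diffuse for t steps, log-transform into potentials, take distances between potentials, and embed with MDS. It is not the PHATE or M-PHATE implementation: the function name phate_like_embedding is hypothetical, classical MDS stands in for the metric MDS used by PHATE, and PHATE's landmarking and automatic selection of t are omitted.

```python
import numpy as np


def phate_like_embedding(K, t=5, n_components=2):
    """Simplified PHATE-style embedding of a symmetric, positive affinity matrix K.

    Hypothetical helper: omits landmarking, automatic choice of t and metric MDS.
    """
    # diffusion operator: row-normalise affinities into a Markov transition matrix
    P = K / K.sum(axis=1, keepdims=True)
    # diffuse for t steps
    Pt = np.linalg.matrix_power(P, t)
    # log "potential" of each row (the small offset avoids log(0))
    U = -np.log(Pt + 1e-7)
    # potential distance: Euclidean distance between rows of U
    D = np.linalg.norm(U[:, None, :] - U[None, :, :], axis=-1)
    # classical MDS: double-centre the squared distances, keep the top eigenvectors
    n = D.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * J @ (D ** 2) @ J
    evals, evecs = np.linalg.eigh(B)
    top = np.argsort(evals)[::-1][:n_components]
    return evecs[:, top] * np.sqrt(np.maximum(evals[top], 0.0))


# toy input: two well-separated clusters of 30 points each in 5 dimensions
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(size=(30, 5)), rng.normal(size=(30, 5)) + 10.0])
sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
K = np.exp(-sq_dists / 50.0)  # plain Gaussian kernel with a fixed bandwidth
Y = phate_like_embedding(K)
```

In M-PHATE, the input K would be the multislice kernel rather than a single Gaussian kernel over static points.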

The multislice kernel used in M-PHATE is built by constructing a graph over each time slice of the data (e.g. each epoch of neural network training) and then connecting the slices by linking each point to itself at other time slices, with each link weighted by the point's similarity to itself at those times. The result is a highly sparse, structured kernel that captures the evolving structure of the data.

[Figure: example of a multislice graph]

[Figure: example of a multislice kernel]
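The multislice construction can be sketched in numpy as follows. This is an illustrative toy, not the m_phate implementation: the function names are hypothetical, and both the within-slice and across-time affinities use an alpha-decay Gaussian whose bandwidth is the distance to the k-th nearest neighbour, an assumption loosely modelled on PHATE-style kernels. Weights below a small threshold are dropped. Rows are ordered slice-major, so row t * n_points + i is point i at slice t.

```python
import numpy as np


def adaptive_affinity(X, knn, decay=5.0, thresh=1e-4):
    """Symmetric alpha-decay affinity with a per-row adaptive bandwidth (hypothetical helper).

    The bandwidth of row i is its distance to its knn-th nearest neighbour.
    """
    if not 1 <= knn < X.shape[0]:
        raise ValueError("knn must be between 1 and n_rows - 1")
    # pairwise Euclidean distances between rows of X
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    # column 0 of the sorted distances is the point itself (distance 0)
    eps = np.maximum(np.sort(D, axis=1)[:, knn], 1e-12)
    A = np.exp(-((D / eps[:, None]) ** decay))
    A = (A + A.T) / 2  # symmetrise
    A[A < thresh] = 0.0  # drop negligible edges to keep the graph sparse
    return A


def multislice_kernel(data, intraslice_knn=2, interslice_knn=2, decay=5.0):
    """Toy multislice kernel for data of shape (n_slices, n_points, n_dim)."""
    n_slices, n_points, _ = data.shape
    K = np.zeros((n_slices * n_points, n_slices * n_points))
    # intraslice edges: one graph per time slice, placed on the diagonal blocks
    for t in range(n_slices):
        block = slice(t * n_points, (t + 1) * n_points)
        K[block, block] = adaptive_affinity(data[t], intraslice_knn, decay)
    # interslice edges: each point links only to itself in other slices,
    # weighted by how similar its own coordinates are across time
    for i in range(n_points):
        idx = np.arange(n_slices) * n_points + i
        A = adaptive_affinity(data[:, i, :], interslice_knn, decay)
        np.fill_diagonal(A, 0.0)  # self-loops already come from the intraslice blocks
        K[np.ix_(idx, idx)] += A
    return K


# random-walk data, as in the basic usage example below (smaller sizes)
rng = np.random.default_rng(42)
data = np.cumsum(rng.normal(size=(10, 20, 5)), axis=0)
K = multislice_kernel(data)
print(K.shape)  # (200, 200)
```

Because each point links only to its own copies in other slices, every off-diagonal block of this toy kernel holds at most n_points nonzeros, which is where much of the sparsity comes from.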

Installation

Install from PyPI

pip install --user m-phate

Install from source

pip install --user git+https://github.com/scottgigante/m-phate.git

Usage

Basic usage example

Below we apply M-PHATE to simulated data: 50 points in 25 dimensions undergoing a random walk over 100 time steps.

import numpy as np
import m_phate
import scprep

# simulate a random walk: cumulative sums of Gaussian steps along the time axis
n_time_steps = 100
n_points = 50
n_dim = 25
np.random.seed(42)
data = np.cumsum(np.random.normal(0, 1, (n_time_steps, n_points, n_dim)), axis=0)

# embedding
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(data)

# plot
time = np.repeat(np.arange(n_time_steps), n_points)
scprep.plot.scatter2d(m_phate_data, c=time, ticks=False, label_prefix="M-PHATE")

[Figure: example embedding]
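The plotting step relies on the embedding having one row per (time step, point) pair. Assuming rows are ordered time-major, as the np.repeat call in the example implies, the small helpers below (hypothetical names, not part of m_phate) build per-row labels and reshape the embedding back into per-point trajectories. A random array stands in for m_phate_op.fit_transform(data) so the sketch runs without m_phate installed.

```python
import numpy as np


def multislice_labels(n_time_steps, n_points):
    """Per-row (time, point) labels, assuming row t * n_points + i is point i at time t."""
    time = np.repeat(np.arange(n_time_steps), n_points)     # 0,0,...,0,1,1,...
    point_id = np.tile(np.arange(n_points), n_time_steps)   # 0,1,...,49,0,1,...
    return time, point_id


def trajectories(embedding, n_time_steps, n_points):
    """Reshape an (n_time_steps * n_points, k) embedding to (n_time_steps, n_points, k)."""
    return embedding.reshape(n_time_steps, n_points, -1)


# stand-in for m_phate_op.fit_transform(data): any (n_time_steps * n_points, 2) array
n_time_steps, n_points = 100, 50
embedding = np.random.default_rng(0).normal(size=(n_time_steps * n_points, 2))
time, point_id = multislice_labels(n_time_steps, n_points)
traj = trajectories(embedding, n_time_steps, n_points)
point_7_path = traj[:, 7, :]  # 2D path of point 7 over time
```

Passing point_id instead of time as c= to scprep.plot.scatter2d colors the embedding by point identity, and traj[:, i, :] gives the path of point i through the embedding.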

Network training

To apply M-PHATE to neural networks, we provide helper classes that record the network's hidden-unit activations on a fixed set of trace examples during training. To use these, you must also install TensorFlow and Keras.

import numpy as np

import keras
import scprep

import m_phate
import m_phate.train
import m_phate.data

# load data
x_train, x_test, y_train, y_test = m_phate.data.load_mnist()

# select 10 random test examples of each digit to trace (y_test is one-hot)
trace_idx = [np.random.choice(np.argwhere(y_test[:, i] == 1).flatten(),
                              10, replace=False)
             for i in range(10)]
trace_data = x_test[np.concatenate(trace_idx)]

# build neural network
lrelu = keras.layers.LeakyReLU(alpha=0.1)
inputs = keras.layers.Input(
    shape=(x_train.shape[1],), dtype='float32', name='inputs')
h1 = keras.layers.Dense(128, activation=lrelu, name='h1')(inputs)
h2 = keras.layers.Dense(128, activation=lrelu, name='h2')(h1)
h3 = keras.layers.Dense(128, activation=lrelu, name='h3')(h2)
outputs = keras.layers.Dense(10, activation='softmax', name='output_all')(h3)

# build trace model helper
model_trace = keras.models.Model(inputs=inputs, outputs=[h1, h2, h3])
trace = m_phate.train.TraceHistory(trace_data, model_trace)

# compile network
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['categorical_accuracy', 'categorical_crossentropy'])

# train network
model.fit(x_train, y_train, batch_size=128, epochs=200,
          verbose=1, callbacks=[trace],
          validation_data=(x_test, y_test))

# stack the per-epoch traces and label each embedded point with its epoch
trace_data = np.array(trace.trace)
epoch = np.repeat(np.arange(trace_data.shape[0]), trace_data.shape[1])

# apply M-PHATE
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(trace_data)

# plot the result
scprep.plot.scatter2d(m_phate_data, c=epoch, ticks=False,
                      label_prefix="M-PHATE")
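What the trace callback records can be illustrated without Keras. The sketch below is a hypothetical, framework-free stand-in for m_phate.train.TraceHistory: the class TraceRecorder and the toy network are our own, and small random weight updates stand in for real training. At the end of every epoch it runs the fixed trace examples through the hidden layers and stores the activations with one row per hidden unit and one column per trace example. That layout is an assumption, chosen to match the epoch labels in the example above, which assign one epoch index to each of the trace_data.shape[1] rows recorded per epoch.

```python
import numpy as np


class TraceRecorder:
    """Hypothetical stand-in for a trace-history callback (not the m_phate API)."""

    def __init__(self, trace_inputs, hidden_fn):
        self.trace_inputs = trace_inputs  # fixed examples, shape (n_inputs, n_features)
        self.hidden_fn = hidden_fn  # maps inputs to a list of (n_inputs, n_units) arrays
        self.trace = []

    def on_epoch_end(self, epoch, logs=None):
        layers = self.hidden_fn(self.trace_inputs)
        # concatenate layers along the unit axis, then transpose so rows are units
        self.trace.append(np.concatenate(layers, axis=1).T)


def leaky_relu(x, alpha=0.1):
    return np.where(x > 0, x, alpha * x)


rng = np.random.default_rng(0)
n_features, n_hidden, n_epochs = 20, 16, 5
W1 = rng.normal(size=(n_features, n_hidden))
W2 = rng.normal(size=(n_hidden, n_hidden))


def hidden_fn(x):
    # two hidden layers, loosely mirroring h1 and h2 in the Keras example
    h1 = leaky_relu(x @ W1)
    h2 = leaky_relu(h1 @ W2)
    return [h1, h2]


trace_inputs = rng.normal(size=(30, n_features))
recorder = TraceRecorder(trace_inputs, hidden_fn)
for ep in range(n_epochs):
    # stand-in for one epoch of training: small random weight updates
    W1 += 0.01 * rng.normal(size=W1.shape)
    W2 += 0.01 * rng.normal(size=W2.shape)
    recorder.on_epoch_end(ep)

trace = np.array(recorder.trace)  # (n_epochs, n_units, n_trace_inputs)
epoch = np.repeat(np.arange(trace.shape[0]), trace.shape[1])
layer = np.tile(np.repeat([1, 2], n_hidden), trace.shape[0])  # hidden layer of each row
```

The layer array labels each row with the hidden layer it came from, which is the kind of coloring used in the center panel of the demonstration plot.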

For detailed examples, see our sample Keras and TensorFlow notebooks in the examples folder.

Parameter tuning

The key to tuning the parameters of M-PHATE is balancing interslice connectivity (between each point and itself at other time slices) against intraslice connectivity (between points within the same time slice). This is primarily controlled by interslice_knn and intraslice_knn. You can see an example of the effects of parameter tuning in this notebook.
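One way to build intuition for this tradeoff, assuming (as in PHATE-style kernels) that a knn parameter sets an adaptive bandwidth equal to the distance to the k-th nearest neighbour: a larger k widens the kernel, so every link gets stronger. The numpy sketch below applies such a kernel to a single point's positions across 50 time slices, which is the role we assume interslice_knn plays, and measures the average interslice weight per slice. The function adaptive_affinity is hypothetical, and m_phate's actual kernel may differ.

```python
import numpy as np


def adaptive_affinity(X, knn, decay=5.0):
    """Symmetric alpha-decay affinity; row i's bandwidth is its distance to its knn-th neighbour."""
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    eps = np.maximum(np.sort(D, axis=1)[:, knn], 1e-12)  # column 0 is the point itself
    A = np.exp(-((D / eps[:, None]) ** decay))
    return (A + A.T) / 2


# one point's positions across 50 time slices: a random walk in 10 dimensions
rng = np.random.default_rng(1)
trajectory = np.cumsum(rng.normal(size=(50, 10)), axis=0)

# average total weight each slice sends to the same point at *other* slices
interslice_weight = {}
for knn in (2, 5, 20):
    A = adaptive_affinity(trajectory, knn)
    np.fill_diagonal(A, 0.0)
    interslice_weight[knn] = A.sum(axis=1).mean()
```

Under the same assumption, raising interslice_knn relative to intraslice_knn shifts each point's diffusion toward its own past and future, while the reverse emphasizes its neighbours within each slice.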

Figure reproduction

We provide scripts to reproduce all of the empirical figures in the preprint.

To run them:

git clone https://github.com/scottgigante/m-phate
cd m-phate
pip install --user .
DATA_DIR=~/data/checkpoints/m_phate # change this if you want to store the data elsewhere

chmod +x scripts/generalization/generalization_train.sh
chmod +x scripts/task_switching/classifier_mnist_task_switch_train.sh

./scripts/generalization/generalization_train.sh $DATA_DIR
./scripts/task_switching/classifier_mnist_task_switch_train.sh $DATA_DIR

python scripts/demonstration_plot.py $DATA_DIR
python scripts/comparison_plot.py $DATA_DIR
python scripts/generalization_plot.py $DATA_DIR
python scripts/task_switch_plot.py $DATA_DIR

# generalization plot using training data
./scripts/generalization/generalization_train.sh ${DATA_DIR}/train_data --sample-train-data
mkdir train_data; cd train_data; python -i ../scripts/generalization_plot.py ${DATA_DIR}/train_data; cd ..

TODO

  • Provide support for PyTorch
  • Notebook examples for:
    • Classification, PyTorch
    • Autoencoder, PyTorch
  • Build readthedocs page
  • Update arXiv link

Help

If you have any questions, please feel free to open an issue.
