FlatNet implementation in PyTorch, from the paper "Representation Learning via Manifold Flattening and Reconstruction"

Project description

This is a minimal pip package to allow easy deployment of FlatNets, a geometry-based neural autoencoder architecture that automatically builds its layers based on geometric properties of the dataset. See our paper, Representation Learning via Manifold Flattening and Reconstruction, for more details, and the GitHub repo for the code and example scripts & notebooks.

Changelog

0.2.1 (2023-05-09)

Features:

  • Added support for gif saving with 3D data.

0.2.0 (2023-05-07)

Features:

  • Added easy gif saving! If your data is 2D, you can now visualize how your data evolves with the manifold flow.

0.1.4 (2023-05-06)

Bug Fixes:

  • Remembered why normalization was left out of the main training file; it has now been removed (reverting the 0.1.1 change).

0.1.3 (2023-05-06)

Bug Fixes:

  • Minor performance improvements on simple cases

0.1.1 (2023-05-06)

Features:

  • Added normalization of the input data

0.1.0 (2022-12-31)

  • Initial release of flatnet!

Manifold Linearization for Representation Learning

This is a research project focused on the automatic generation of autoencoders with minimal feature size, when the data is supported near an embedded submanifold. Using the geometric structure of the manifold, we can equivalently treat this problem as a manifold flattening problem when the manifold is flattenable[^1]. See our paper, Representation Learning via Manifold Flattening and Reconstruction, for more details.

[^1]: Geometric note: while flattenability is not general, there are some heuristic reasons we can motivate this assumption for real world data. For example, if a dataset permits a VAE-like autoencoder, where samples from the data distribution can be generated via a standard Gaussian in the latent space, then the samples lie close in probability to a flattenable manifold, as this VAE has constructed a single-chart atlas.

Installation

The pure FlatNet construction (training) code is available as a pip package and can be installed in the following way:

pip install flatnet

This repo also contains a number of testing and illustrative files to both familiarize new users with the framework and show the experiments run in the main paper. To install the appropriate remaining dependencies for this repo, first navigate to the project directory, then run the following command:

pip install -r requirements.txt

Quickstart usage

The following is a simple example script using the pip package:

import torch
import flatnet
import matplotlib.pyplot as plt

# create sine wave dataset
t = torch.linspace(0, 1, 50)
y = torch.sin(t * 2 * 3.14)

# format dataset of N points of dimension D as (N, D) matrix
X = torch.stack([t, y], dim=1)

# normalize data
X = (X - X.mean(dim=0)) / X.std(dim=0)

# f and g are both functions from R^D to R^D
f, g = flatnet.train(X, n_iter=50)

Z = f(X).detach().numpy()

plt.scatter(Z[:,0], Z[:,1])
plt.show()
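
Since g maps the flattened representation back to the original space, a natural follow-up is to check the reconstruction g(f(X)). The snippet below is a minimal sketch continuing the script above; it assumes g accepts the full (N, D) matrix batchwise, just as f does in the quickstart.

# reconstruct the data by passing the flattened representation back through g
# (assumes g, like f, operates on the (N, D) matrix row-wise)
X_hat = g(f(X)).detach().numpy()
X_np = X.detach().numpy()

# overlay the original and reconstructed sine wave
plt.scatter(X_np[:, 0], X_np[:, 1], label='original')
plt.scatter(X_hat[:, 0], X_hat[:, 1], label='reconstructed', marker='x')
plt.legend()
plt.show()

# simple quantitative check: mean squared reconstruction error
mse = ((g(f(X)) - X) ** 2).mean().item()
print(f'reconstruction MSE: {mse:.4e}')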

The script flatnet_test.py includes many example experiments for running FlatNet constructions. For a quick demonstration, run python flatnet_test.py in the main directory to see the flattening and reconstruction of a simple sine wave. Further experiments and options can be specified through command-line arguments, managed through tyro; to see the full list of arguments, run python flatnet_test.py --help.

Directory Structure

  • flatnet_test.py: main test script, as described in the Quickstart section above.
  • flatnet/train.py: contains the main FlatNet construction (training) code.
  • flatnet/modules: contains code for the neural network modules used in FlatNet.
  • experiments-paper: contains scripts and results from experiments done in the paper.
  • models: contains code for various models that FlatNet was compared against in the paper.
  • tools: contains auxiliary tools for evaluating the method, such as random manifold generators.

Citation

If you use this work in your research, please cite the following paper:

@article{psenka2023flatnet,
  author = {Psenka, Michael and Pai, Druv and Raman, Vishal and Sastry, Shankar and Ma, Yi},
  title = {Representation Learning via Manifold Flattening and Reconstruction},
  year = {2023},
  eprint = {2305.01777},
  url = {https://arxiv.org/abs/2305.01777},
}

We hope that you find this project useful. If you have any questions or suggestions, please feel free to contact us.
