
FlatNet implementation in PyTorch, from the paper "Representation Learning via Manifold Flattening and Reconstruction"

Project description

This is a minimal pip package for easy deployment of FlatNets, a geometry-based neural autoencoder architecture that automatically builds its layers from geometric properties of the dataset. See our paper, Representation Learning via Manifold Flattening and Reconstruction, for more details, and the GitHub repo for the code and example scripts & notebooks.

Changelog

0.2.1 (2023-05-09)

Features:

  • Added GIF-saving support for 3D data.

0.2.0 (2023-05-07)

Features:

  • Added easy GIF saving! If your data is 2D, you can now visualize how it evolves along the manifold flow.

0.1.4 (2023-05-06)

Bug Fixes:

  • Removed normalization from the main training file, after remembering why it was left out in the first place

0.1.3 (2023-05-06)

Bug Fixes:

  • Minor performance improvements on simple cases

0.1.1 (2023-05-06)

Features:

  • Added normalization of the input data

0.1.0 (2022-12-31)

  • Initial release of flatnet!

Manifold Linearization for Representation Learning

This is a research project focused on the automatic generation of autoencoders with minimal feature size, when the data is supported near an embedded submanifold. Using the geometric structure of the manifold, we can equivalently treat this problem as a manifold flattening problem when the manifold is flattenable[^1]. See our paper, Representation Learning via Manifold Flattening and Reconstruction, for more details.

[^1]: Geometric note: while flattenability is not general, there are some heuristic reasons we can motivate this assumption for real world data. For example, if a dataset permits a VAE-like autoencoder, where samples from the data distribution can be generated via a standard Gaussian in the latent space, then the samples lie close in probability to a flattenable manifold, as this VAE has constructed a single-chart atlas.
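In symbols, a sketch of the footnote's argument (the notation $d$, $D$, $g$, and $\mathcal{M}$ is ours, not taken from the paper): if a trained VAE-like decoder generates the data from a standard Gaussian latent,

```latex
\[
x = g(z), \qquad z \sim \mathcal{N}(0, I_d), \qquad g : \mathbb{R}^d \to \mathbb{R}^D,
\]
% then the data concentrates near the image $\mathcal{M} = g(\mathbb{R}^d)$,
% which (where $g$ is injective near the data) is covered by the single chart
% $(\mathcal{M},\, g^{-1})$; a manifold admitting a single-chart atlas is
% flattenable by construction.
```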

Installation

The pure FlatNet construction (training) code is available as a pip package and can be installed in the following way:

pip install flatnet

This repo also contains a number of testing and illustrative files to both familiarize new users with the framework and show the experiments run in the main paper. To install the appropriate remaining dependencies for this repo, first navigate to the project directory, then run the following command:

pip install -r requirements.txt

Quickstart usage

The following is a simple example script using the pip package:

import torch
import flatnet
import matplotlib.pyplot as plt

# create sine wave dataset
t = torch.linspace(0, 1, 50)
y = torch.sin(t * 2 * torch.pi)

# format dataset of N points of dimension D as (N, D) matrix
X = torch.stack([t, y], dim=1)

# normalize data
X = (X - X.mean(dim=0)) / X.std(dim=0)

# f and g are both functions from R^D to R^D
f, g = flatnet.train(X, n_iter=50)

Z = f(X).detach().numpy()

plt.scatter(Z[:,0], Z[:,1])
plt.show()

The script flatnet_test.py includes many example experiments to run FlatNet constructions on. To see an example experiment, simply run python flatnet_test.py in the main directory to see the flattening and reconstruction of a simple sine wave. Further experiments and options can be specified through command line arguments, managed through tyro; to see the full list of arguments, run python flatnet_test.py --help.

Directory Structure

  • flatnet_test.py: main test script, as described in the section above.
  • flatnet/train.py: contains the main FlatNet construction (training) code.
  • flatnet/modules: contains code for the neural network modules used in FlatNet.
  • experiments-paper: contains scripts and results from experiments done in the paper.
  • models: contains code for various models that FlatNet was compared against in the paper.
  • tools: contains auxiliary tools for evaluating the method, such as random manifold generators.

Citation

If you use this work in your research, please cite the following paper:

@article{psenka2023flatnet,
  author = {Psenka, Michael and Pai, Druv and Raman, Vishal and Sastry, Shankar and Ma, Yi},
  title = {Representation Learning via Manifold Flattening and Reconstruction},
  year = {2023},
  eprint = {2305.01777},
  url = {https://arxiv.org/abs/2305.01777},
}

We hope that you find this project useful. If you have any questions or suggestions, please feel free to contact us.

Download files

Download the file for your platform. If you're not sure which to choose, see the Python Packaging User Guide for more on installing packages.

Source Distribution

flatnet-0.2.1.tar.gz (12.5 kB, source)

Built Distribution

flatnet-0.2.1-py3-none-any.whl (10.8 kB, Python 3)

File details

Details for the file flatnet-0.2.1.tar.gz.

File metadata

  • Download URL: flatnet-0.2.1.tar.gz
  • Upload date:
  • Size: 12.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for flatnet-0.2.1.tar.gz:

  • SHA256: 278dc9e8fe7c0e1bf95b355adbc162bc440321f0b21e733fa4700e8d990daee9
  • MD5: 22eafbcc768297c5077fae7ebdbcf5b2
  • BLAKE2b-256: af5d3ea17e24481dcfcfb17cbb758da37c530816e71481e8368e6dec8d87025c

See more details on using hashes here.
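The listed digests can be checked locally with Python's standard hashlib; a minimal sketch (the filename below assumes you saved the source distribution to the current directory):

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 8192) -> str:
    """Return the hex SHA256 digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# SHA256 digest listed above for the source distribution
EXPECTED = "278dc9e8fe7c0e1bf95b355adbc162bc440321f0b21e733fa4700e8d990daee9"

# uncomment after downloading the tarball:
# assert sha256_of("flatnet-0.2.1.tar.gz") == EXPECTED
```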

File details

Details for the file flatnet-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: flatnet-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 10.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.10

File hashes

Hashes for flatnet-0.2.1-py3-none-any.whl:

  • SHA256: 531e0b1f1ce13c5962fcce58365d24c57c5caaa4537d38c3240be9bdc44435f6
  • MD5: 43739ad6d8791acaed0465ec71175278
  • BLAKE2b-256: a87a322aa766297bfbc9c6237917624712958d8a152be08892fc569bf1b82be3

