A PyTorch implementation of PersLay.

torchPersLay

torchPersLay is a PyTorch implementation of PersLay, a neural network layer for processing persistence diagrams in topological data analysis (TDA). The original PersLay architecture is available in GUDHI, but only in TensorFlow. This project provides a native, modular, and extensible PyTorch version suitable for modern deep-learning pipelines.

Citation

If you use this neural network layer in your research, please cite the original paper:

PersLay: A Neural Network Layer for Persistence Diagrams and New Graph Topological Signatures
Mathieu Carrière, Frédéric Chazal, Yuichi Ike, Théo Lacombe, Martin Royer, Yuhei Umeda
Proceedings of the Twenty-Third International Conference on Artificial Intelligence and Statistics (AISTATS),
PMLR 108:2786–2796, 2020.

Installation

Install it from the Python Package Index (PyPI) with:

pip install torchperslay

Example Usage

This example builds a simple regression model that uses PersLay as its single hidden layer.

Import Packages

Import the necessary packages, making sure they are all installed on your system.

from torchPersLay import Perslay, PowerPerslayWeight, GaussianPerslayPhi

import gudhi.representations as gdr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler

Define PersLay Layers

Choose the PersLay components you want: a point weight, a point transformation phi, a permutation-invariant operation, and a final map rho. Consult the official PersLay documentation for the available options.

constant = 1.0
power = 0.0

weight = PowerPerslayWeight(constant=constant, power=power)
rho = nn.Identity()

image_size = (5, 5)
image_bnds = ((-0.5, 1.5), (-0.5, 1.5))
variance = 0.1

phi = GaussianPerslayPhi(
    image_size=image_size,
    image_bnds=image_bnds,
    variance=variance,
)

perm_op = torch.sum

perslay = Perslay(weight=weight, phi=phi, perm_op=perm_op, rho=rho)
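Conceptually, PersLay maps each diagram point through phi, multiplies by its weight, and aggregates over points with the permutation-invariant operation. A minimal pure-PyTorch sketch of the configuration above (a hypothetical toy_perslay helper; torchPersLay's actual internals may differ):

```python
import torch

def toy_perslay(diagram, constant=1.0, power=0.0,
                image_size=(5, 5), image_bnds=((-0.5, 1.5), (-0.5, 1.5)),
                variance=0.1):
    # Point weight: constant * ||p||^power (power=0 gives a constant weight).
    w = constant * torch.linalg.norm(diagram, dim=-1) ** power
    # Gaussian persistence image: one Gaussian bump per (birth, death)
    # point, evaluated on a fixed grid.
    xs = torch.linspace(image_bnds[0][0], image_bnds[0][1], image_size[0])
    ys = torch.linspace(image_bnds[1][0], image_bnds[1][1], image_size[1])
    gx, gy = torch.meshgrid(xs, ys, indexing="ij")
    grid = torch.stack([gx, gy], dim=-1)                    # [H, W, 2]
    diff = grid[None] - diagram[:, None, None, :]           # [N, H, W, 2]
    phi = torch.exp(-(diff ** 2).sum(-1) / (2 * variance))  # [N, H, W]
    # Permutation-invariant aggregation over points (perm_op = sum).
    return (w[:, None, None] * phi).sum(dim=0)              # [H, W]

diagram = torch.tensor([[0.0, 0.5], [0.25, 1.0]])
img = toy_perslay(diagram)
print(img.shape)  # torch.Size([5, 5])
```

The key property is that the output is independent of the order of points in the diagram, which is why a set-level aggregation like sum is used.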

Import Data

This is toy example data; substitute your own persistence diagrams. Each diagram is an array of (birth, death) pairs.

diagrams = [
    np.array([[0.0, 4.0], [1.0, 2.0], [3.0, 8.0], [6.0, 8.0]]),
    np.array([[1.0, 3.0], [2.0, 2.5], [4.0, 7.0], [7.0, 7.5]]),
]

scaler = gdr.DiagramScaler(use=True, scalers=[([0, 1], MinMaxScaler())])
diagrams = scaler.fit_transform(diagrams)
diagrams = torch.from_numpy(np.array(diagrams, dtype=np.float32))

y = torch.tensor([[1.0], [3.0]])
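Under the hood, DiagramScaler fits the chosen scaler on the points of all diagrams pooled together, then applies that shared fit to each diagram. A rough sketch with plain scikit-learn (not GUDHI's exact implementation):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

diagrams = [
    np.array([[0.0, 4.0], [1.0, 2.0], [3.0, 8.0], [6.0, 8.0]]),
    np.array([[1.0, 3.0], [2.0, 2.5], [4.0, 7.0], [7.0, 7.5]]),
]

# Fit on the points of all diagrams pooled together, then transform
# each diagram with the shared fit.
scaler = MinMaxScaler().fit(np.vstack(diagrams))
scaled = [scaler.transform(d) for d in diagrams]
print(scaled[0].min(), scaled[1].max())  # 0.0 1.0
```

Scaling both coordinates into [0, 1] is what makes the image bounds ((-0.5, 1.5), (-0.5, 1.5)) chosen above cover every point with some margin.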

Define Model

Define the model. This is a simple example: you may add other layers as usual and/or use known architectures to concatenate feature vectors.

class PersLayRegressor(nn.Module):
    def __init__(self, perslay, image_size=(5, 5)):
        super().__init__()
        self.perslay = perslay
        feature_dim = image_size[0] * image_size[1]

        self.regressor = nn.Sequential(
            nn.Linear(feature_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, diagrams):
        x = self.perslay(diagrams)       # [B, 5, 5, 1]
        x = x.view(x.shape[0], -1)       # [B, 25]
        return self.regressor(x)

model = PersLayRegressor(perslay)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
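The .view call in forward simply flattens the persistence image into a feature vector for the first Linear layer; with a dummy tensor in place of the PersLay output:

```python
import torch

x = torch.rand(2, 5, 5, 1)     # stand-in for the PersLay output [B, H, W, 1]
flat = x.view(x.shape[0], -1)  # [B, 25], ready for nn.Linear(25, 32)
print(flat.shape)  # torch.Size([2, 25])
```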

Train Model

Train the model with a standard PyTorch training loop.

for epoch in range(100):
    optimizer.zero_grad()
    preds = model(diagrams)
    loss = criterion(preds, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch + 1}, Loss = {loss.item():.6f}")
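The loop above requires torchPersLay; the same training pattern can be checked end to end by substituting random tensors for the flattened PersLay features (purely illustrative):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Stand-in for flattened PersLay features: two 25-dimensional vectors.
x = torch.rand(2, 25)
y = torch.tensor([[1.0], [3.0]])

model = nn.Sequential(nn.Linear(25, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

losses = []
for epoch in range(100):
    optimizer.zero_grad()         # clear accumulated gradients
    loss = criterion(model(x), y)
    loss.backward()               # backpropagate
    optimizer.step()              # update parameters
    losses.append(loss.item())
# The loss should decrease over the 100 epochs on this toy data.
```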

(Optional) Inspect Learned Parameters

You can optionally inspect the learned parameters.

for name, param in perslay.named_parameters():
    print(name, param.data)
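named_parameters() yields a (name, tensor) pair for every trainable parameter registered on the module; for example, on a plain Linear layer:

```python
import torch.nn as nn

layer = nn.Linear(2, 1)
for name, param in layer.named_parameters():
    print(name, tuple(param.shape))
# weight (1, 2)
# bias (1,)
```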
