
torchPersLay

torchPersLay is a PyTorch implementation of PersLay, a neural network layer for processing persistence diagrams in topological data analysis (TDA). The original PersLay architecture is available in GUDHI, but only in TensorFlow. This project provides a native, modular, and extensible PyTorch version suitable for modern deep-learning pipelines.
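In the notation of the original paper, PersLay maps a diagram D to rho(op over points p in D of w(p) * phi(p)), where w is a learnable weight function, phi a learnable point transformation, op a permutation-invariant operation, and rho a post-processing map. A minimal plain-numpy sketch of this equation, with toy choices for each component (this illustrates the structure only, not the library's implementation):

```python
import numpy as np

def perslay_sketch(diagram, weight, phi, perm_op, rho):
    """Conceptual sketch of the PersLay equation:
    PersLay(D) = rho( op_{p in D} w(p) * phi(p) ).
    `diagram` is an (n, 2) array of (birth, death) points."""
    vectors = np.stack([weight(p) * phi(p) for p in diagram])  # one vector per point
    return rho(perm_op(vectors, axis=0))                       # permutation-invariant pooling

# Toy instantiation: constant weight, raw-coordinate phi, sum pooling, identity rho.
out = perslay_sketch(
    np.array([[0.0, 1.0], [0.5, 2.0]]),
    weight=lambda p: 1.0,
    phi=lambda p: p,
    perm_op=np.sum,
    rho=lambda x: x,
)
print(out)  # coordinate-wise sum of the two points: [0.5, 3.0]
```

The layers defined in the example below are learnable PyTorch instantiations of these four components.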

Citation

If you use this neural network layer in your research, please cite the original paper:

PersLay: A Neural Network Layer for Persistence Diagrams and New Graph Topological Signatures
Mathieu Carrière, Frédéric Chazal, Yuichi Ike, Théo Lacombe, Martin Royer, Yuhei Umeda
Proceedings of the Twenty-Third International Conference on Artificial Intelligence and Statistics (AISTATS),
PMLR 108:2786–2796, 2020.

Installation

You can install it from the Python Package Index with

pip install torchperslay

Example Usage

This example builds a simple regression model that uses PersLay as its single hidden layer.

Import Packages

Import the necessary packages. Make sure they are all installed in your environment.

from torchPersLay import PowerPerslayWeight, GaussianPerslayPhi, Perslay

import gudhi.representations as gdr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler

Define PersLay Layers

Choose the PersLay components (weight, phi, permutation operation, and rho) that you want. Consult the official PersLay documentation for the available options.

constant = 1.0
power = 0.0

weight = PowerPerslayWeight(constant=constant, power=power)
rho = nn.Identity()

image_size = (5, 5)
image_bnds = ((-0.5, 1.5), (-0.5, 1.5))
variance = 0.1

phi = GaussianPerslayPhi(
    image_size=image_size,
    image_bnds=image_bnds,
    variance=variance,
)

perm_op = torch.sum

perslay = Perslay(weight=weight, phi=phi, perm_op=perm_op, rho=rho)
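GaussianPerslayPhi turns each diagram point into a small image by evaluating a Gaussian bump centered at that point on a fixed grid, in the spirit of the persistence-image construction. A plain-numpy sketch of one such Gaussian image, using the same grid parameters as above (an illustration of the idea, not gudhi's implementation):

```python
import numpy as np

def gaussian_image(point, image_size=(5, 5),
                   image_bnds=((-0.5, 1.5), (-0.5, 1.5)), variance=0.1):
    """Evaluate a Gaussian bump centered at `point` on a fixed 2D grid."""
    (x0, x1), (y0, y1) = image_bnds
    xs = np.linspace(x0, x1, image_size[0])
    ys = np.linspace(y0, y1, image_size[1])
    gx, gy = np.meshgrid(xs, ys, indexing="ij")
    sq_dist = (gx - point[0]) ** 2 + (gy - point[1]) ** 2
    return np.exp(-sq_dist / (2.0 * variance))

img = gaussian_image(np.array([0.5, 0.5]))
print(img.shape)  # one (5, 5) image per diagram point
```

Summing these per-point images (the perm_op above) yields a single persistence image per diagram.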

Import Data

This is example data; replace it with your own persistence diagrams.

diagrams = [
    np.array([[0.0, 4.0], [1.0, 2.0], [3.0, 8.0], [6.0, 8.0]]),
    np.array([[1.0, 3.0], [2.0, 2.5], [4.0, 7.0], [7.0, 7.5]]),
]

scaler = gdr.DiagramScaler(use=True, scalers=[([0, 1], MinMaxScaler())])
diagrams = scaler.fit_transform(diagrams)
diagrams = torch.from_numpy(np.array(diagrams, dtype=np.float32))

y = torch.tensor([[1.0], [3.0]])
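The DiagramScaler step rescales each coordinate to [0, 1] using the minimum and maximum taken over the points of all diagrams, so that the scaled points fall inside the image bounds chosen above. Conceptually (a plain-numpy sketch of the effect, not of gudhi's code):

```python
import numpy as np

# Min-max scaling of diagram coordinates, fitted on the pooled points
# of ALL diagrams and then applied to each diagram.
diagrams = [
    np.array([[0.0, 4.0], [1.0, 2.0], [3.0, 8.0], [6.0, 8.0]]),
    np.array([[1.0, 3.0], [2.0, 2.5], [4.0, 7.0], [7.0, 7.5]]),
]
points = np.concatenate(diagrams)              # pool points across diagrams
lo, hi = points.min(axis=0), points.max(axis=0)
scaled = [(d - lo) / (hi - lo) for d in diagrams]
print(scaled[0][0])  # first point of the first diagram, now in [0, 1]
```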

Define Model

Define the model. This is a simple example; you may add further layers as usual and/or combine the PersLay feature vector with other features using a standard architecture.

class PersLayRegressor(nn.Module):
    def __init__(self, perslay, image_size=(5, 5)):
        super().__init__()
        self.perslay = perslay
        feature_dim = image_size[0] * image_size[1]

        self.regressor = nn.Sequential(
            nn.Linear(feature_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, diagrams):
        x = self.perslay(diagrams)       # [B, 5, 5, 1]
        x = x.view(x.shape[0], -1)       # [B, 25]
        return self.regressor(x)

model = PersLayRegressor(perslay)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

Train Model

Train the model with a standard PyTorch training loop.

for epoch in range(100):
    optimizer.zero_grad()
    preds = model(diagrams)
    loss = criterion(preds, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch + 1}, Loss = {loss.item():.6f}")
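The loop above follows the standard PyTorch pattern: zero the gradients, run a forward pass, compute the loss, backpropagate, and take an optimizer step. The same mechanics written out by hand for a one-parameter linear model, in plain numpy:

```python
import numpy as np

# Hand-rolled gradient descent for y ~ w * x, mirroring the four steps of
# the PyTorch loop above: forward pass, MSE loss, gradient, parameter update.
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])   # true w = 2
w, lr = 0.0, 0.05

for epoch in range(100):
    preds = w * x                        # forward pass
    loss = np.mean((preds - y) ** 2)     # MSE loss (criterion)
    grad = np.mean(2 * (preds - y) * x)  # d loss / d w (loss.backward())
    w -= lr * grad                       # optimizer.step()

print(round(w, 3))  # converges to 2.0
```

Autograd and the optimizer do exactly this bookkeeping for every parameter of the PersLay model.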

(Optional) Inspect Learned Parameters

You may optionally inspect the learned parameters.

for name, param in perslay.named_parameters():
    print(name, param.data)
