
torchPersLay

torchPersLay is a PyTorch implementation of PersLay, a neural network layer for processing persistence diagrams in topological data analysis (TDA). The original PersLay architecture is available in GUDHI, but only in TensorFlow. This project provides a native, modular, and extensible PyTorch version suitable for modern deep-learning pipelines.
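At its core, PersLay encodes a diagram D as rho(op over p in D of w(p) · phi(p)): each point p is weighted by w, transformed by phi, the results are pooled with a permutation-invariant operation op, and post-processed by rho. The pure-Python sketch below illustrates that composition only; `perslay_sketch` and its arguments are illustrative names, not the torchPersLay API.

```python
# Conceptual sketch of the PersLay map from the paper:
#   PersLay(D) = rho( op_{p in D} [ w(p) * phi(p) ] )
# `perslay_sketch` is an illustrative helper, not part of torchPersLay.
def perslay_sketch(diagram, w, phi, op, rho):
    vectors = [[w(p) * v for v in phi(p)] for p in diagram]  # weight each transformed point
    pooled = [op(col) for col in zip(*vectors)]              # permutation-invariant pooling
    return rho(pooled)                                       # final post-processing

# Toy run: constant weight, phi maps (b, d) to (b, d, d - b), sum pooling, identity rho.
out = perslay_sketch(
    [(0.0, 4.0), (1.0, 2.0)],
    w=lambda p: 1.0,
    phi=lambda p: (p[0], p[1], p[1] - p[0]),
    op=sum,
    rho=lambda v: v,
)
# → [1.0, 6.0, 5.0]
```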

Citation

If you use this neural network layer in your research, please cite the original paper:

PersLay: A Neural Network Layer for Persistence Diagrams and New Graph Topological Signatures
Mathieu Carrière, Frédéric Chazal, Yuichi Ike, Théo Lacombe, Martin Royer, Yuhei Umeda
Proceedings of the Twenty-Third International Conference on Artificial Intelligence and Statistics (AISTATS),
PMLR 108:2786–2796, 2020.

Installation

Install torchPersLay from the Python Package Index using

pip install torchperslay

Example Usage

This example builds a simple regression model that uses PersLay as its single hidden layer.

Import Packages

Import the necessary packages, and make sure each one is installed on your system.

from torchPersLay import *

import gudhi.representations as gdr
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler

Define PersLay Layers

Choose the PersLay components you want: a weight function, a point transformation phi, a permutation-invariant operation, and a post-processing map rho. Consult the official PersLay documentation for the available options.

constant = 1.0
power = 0.0

weight = PowerPerslayWeight(constant=constant, power=power)
rho = nn.Identity()

image_size = (5, 5)
image_bnds = ((-0.5, 1.5), (-0.5, 1.5))
variance = 0.1

phi = GaussianPerslayPhi(
    image_size=image_size,
    image_bnds=image_bnds,
    variance=variance,
)

perm_op = torch.sum

perslay = Perslay(weight=weight, phi=phi, perm_op=perm_op, rho=rho)
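For intuition, GaussianPerslayPhi turns each diagram point into a Gaussian bump evaluated on a fixed grid, i.e. a persistence image. The sketch below is a rough pure-Python illustration of that idea under simplifying assumptions (raw (birth, death) coordinates, no channel dimension); `gaussian_image` is a hypothetical helper, and the real layer's coordinate handling may differ.

```python
import math

# Hypothetical sketch of a Gaussian persistence image: every diagram point
# (b, d) contributes a Gaussian bump on a fixed grid. Simplified relative to
# GaussianPerslayPhi (no channel axis, raw coordinates).
def gaussian_image(diagram, size=(5, 5), bnds=((-0.5, 1.5), (-0.5, 1.5)), variance=0.1):
    xs = [bnds[0][0] + (bnds[0][1] - bnds[0][0]) * i / (size[0] - 1) for i in range(size[0])]
    ys = [bnds[1][0] + (bnds[1][1] - bnds[1][0]) * j / (size[1] - 1) for j in range(size[1])]
    return [[sum(math.exp(-((x - b) ** 2 + (y - d) ** 2) / (2 * variance))
                 for b, d in diagram)
             for x in xs]
            for y in ys]

img = gaussian_image([(0.2, 0.8)])  # 5x5 grid with a single bump near (0.2, 0.8)
```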

Import Data

This is example data; replace it with your own persistence diagrams.

diagrams = [
    np.array([[0.0, 4.0], [1.0, 2.0], [3.0, 8.0], [6.0, 8.0]]),
    np.array([[1.0, 3.0], [2.0, 2.5], [4.0, 7.0], [7.0, 7.5]]),
]

scaler = gdr.DiagramScaler(use=True, scalers=[([0, 1], MinMaxScaler())])
diagrams = scaler.fit_transform(diagrams)  # rescale both coordinates to [0, 1]
# Note: stacking into a single tensor requires every diagram to have the same number of points
diagrams = torch.from_numpy(np.array(diagrams, dtype=np.float32))

y = torch.tensor([[1.0], [3.0]])
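For intuition, the DiagramScaler step above fits a MinMaxScaler over the points of all diagrams and maps each coordinate into [0, 1] (that fit-across-all-diagrams behavior is my reading of gdr.DiagramScaler; verify against the GUDHI docs). A pure-Python sketch with the hypothetical helper `min_max_scale`:

```python
# Hypothetical sketch of per-coordinate min-max scaling across all diagrams;
# an illustration, not the gdr.DiagramScaler implementation.
def min_max_scale(diagrams):
    pts = [p for diag in diagrams for p in diag]
    lo = [min(p[i] for p in pts) for i in range(2)]  # per-coordinate minimum
    hi = [max(p[i] for p in pts) for i in range(2)]  # per-coordinate maximum
    return [[[(p[i] - lo[i]) / (hi[i] - lo[i]) for i in range(2)] for p in diag]
            for diag in diagrams]

scaled = min_max_scale([[(0.0, 4.0), (6.0, 8.0)], [(1.0, 3.0), (7.0, 7.5)]])
# each coordinate now spans exactly [0, 1]
```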

Define Model

This defines the model. It is deliberately simple; you may add further layers as usual and/or concatenate the PersLay feature vector with features from other architectures.

class PersLayRegressor(nn.Module):
    def __init__(self, perslay, image_size=(5, 5)):
        super().__init__()
        self.perslay = perslay
        feature_dim = image_size[0] * image_size[1]

        self.regressor = nn.Sequential(
            nn.Linear(feature_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, diagrams):
        x = self.perslay(diagrams)       # [B, 5, 5, 1]
        x = x.view(x.shape[0], -1)       # [B, 25]
        return self.regressor(x)

model = PersLayRegressor(perslay)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

Train Model

Train the model with a standard PyTorch training loop.

for epoch in range(100):
    optimizer.zero_grad()
    preds = model(diagrams)
    loss = criterion(preds, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch + 1}, Loss = {loss.item():.6f}")

(Optional) Inspect Learned Parameters

Optionally, inspect the learned parameters.

for name, param in perslay.named_parameters():
    print(name, param.data)
