Fourier Sliced-Wasserstein (FSW) embedding — a PyTorch-based library

This package provides an implementation of the Fourier Sliced-Wasserstein (FSW) embedding for multisets and measures, introduced in our ICLR 2025 paper:

Fourier Sliced-Wasserstein Embedding for Multisets and Measures
Tal Amir, Nadav Dym
International Conference on Learning Representations (ICLR), 2025


📦 Requirements

  • Python ≥ 3.10.3 (released March 2022)
  • PyTorch ≥ 2.1.0 (released October 2023)
  • NumPy ≥ 1.24.4 (released June 2023)

The core package has been tested on Linux and Windows.
It may also run on macOS (CPU only), though this has not been verified.
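
To confirm that your environment meets these minimums before installing, you can run a quick check. The snippet below uses only the standard version attributes of Python, PyTorch, and NumPy; nothing fswlib-specific is assumed.

import sys
import numpy as np
import torch

# Compare these against the minimum versions listed above.
print('Python :', sys.version.split()[0])   # needs >= 3.10.3
print('PyTorch:', torch.__version__)        # needs >= 2.1.0
print('NumPy  :', np.__version__)           # needs >= 1.24.4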


🔧 Installation

To install the package:

pip install fswlib

The core package runs on both CPU and CUDA-enabled GPUs, using PyTorch's standard CUDA backend.

In addition, it includes an optional custom CUDA extension that can provide up to 2× speedup for sparse weight matrices (e.g., sparse graphs). This extension is currently supported only on Linux.

To compile the optional extension, run:

fswlib-build
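
Before compiling, it can help to confirm that your PyTorch installation actually sees a CUDA device. The check below uses only standard PyTorch calls; it does not verify anything specific to the fswlib extension itself.

import torch

# The optional extension targets CUDA-enabled GPUs on Linux; this only confirms
# that PyTorch was built with CUDA support and can see at least one device.
print('CUDA available:    ', torch.cuda.is_available())
print('PyTorch CUDA build:', torch.version.cuda)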

📘 Usage Example

Below is a basic usage example of the FSWEmbedding class.

For more examples, see the examples/ directory.
Full API documentation is available at https://tal-amir.github.io/fswlib.

import torch

from fswlib import FSWEmbedding

dtype = torch.float32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

d = 15  # dimension of input multiset elements
n = 50  # multiset size
m = 123 # embedding output dimension

# If encode_total_mass is set to False, input multisets are treated as uniform distributions
# over their elements, making the embedding invariant to the multiset size
# (see the size-invariance sketch after the example output below).
encode_total_mass = True

# Generate an embedding module
embed = FSWEmbedding(d_in=d, d_out=m, encode_total_mass=encode_total_mass, device=device, dtype=dtype)

# Generate and embed one multiset
X = torch.randn(size=(n,d), dtype=dtype, device=device)
X_emb = embed(X)

# Generate and embed a batch of multisets
# Supports input with any number of batch dimensions
batch_dims = (5,3,4)
Xb = torch.randn(size=batch_dims+(n,d), dtype=dtype, device=device)
Xb_emb = embed(Xb)

print(f"Dimension of multiset elements: {d}\nEmbedding dimension: {m}")
print(f'\nOne input multiset X of size {n}:')
print('Shape of X: ', X.shape)
print('Shape of embed(X): ', X_emb.shape)
batch_dim_str = "×".join(str(b) for b in batch_dims)
print(f'\nA batch Xb of {batch_dim_str} input multisets, each of size {n}: ')
print('Shape of Xb: ', Xb.shape)
print('Shape of embed(Xb): ', Xb_emb.shape)

Output:

Dimension of multiset elements: 15
Embedding dimension: 123

One input multiset X of size 50:
Shape of X:  torch.Size([50, 15])
Shape of embed(X):  torch.Size([123])

A batch Xb of 5×3×4 input multisets, each of size 50:
Shape of Xb:  torch.Size([5, 3, 4, 50, 15])
Shape of embed(Xb):  torch.Size([5, 3, 4, 123])
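
The comment on encode_total_mass in the example above states that, with encode_total_mass=False, a multiset is treated as the uniform distribution over its elements, so duplicating every element should leave the embedding unchanged. Below is a minimal sketch of that check; it assumes only the behavior described in that comment and the constructor arguments already shown above.

import torch
from fswlib import FSWEmbedding

embed_dist = FSWEmbedding(d_in=15, d_out=123, encode_total_mass=False,
                          device='cpu', dtype=torch.float32)

X = torch.randn(50, 15)
X_doubled = torch.cat([X, X], dim=0)  # same uniform distribution, twice the total mass

# With encode_total_mass=False, both embeddings should coincide up to numerical error.
print(torch.allclose(embed_dist(X), embed_dist(X_doubled), atol=1e-5))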

📄 Citation

If you use this library in your research, please cite our paper:

@inproceedings{amir2025fsw,
  title={Fourier Sliced-{W}asserstein Embedding for Multisets and Measures},
  author={Tal Amir and Nadav Dym},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}

🔗 Links

  • Documentation: https://tal-amir.github.io/fswlib

👨🏻‍🔧 Maintainer

This library is maintained by Tal Amir.
Contact: talamir@technion.ac.il

Download files

Download the file for your platform.

Source Distribution

  • fswlib-0.1.24.tar.gz (50.7 kB)

Built Distribution


  • fswlib-0.1.24-py3-none-any.whl (51.6 kB)

File details

Details for the file fswlib-0.1.24.tar.gz.

File metadata

  • Download URL: fswlib-0.1.24.tar.gz
  • Upload date:
  • Size: 50.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for fswlib-0.1.24.tar.gz:

  • SHA256: 686e350599e5cd48f4c7b20ac583a5f5c3e92f9ead605c1c1e1e30d7fa85c048
  • MD5: 68b6a0b4272e399e748fb7e18a9e9c52
  • BLAKE2b-256: 2aeb6a51844ac3bca747e027477129f9993fa3f7ced4f748781c222252fcfd9b


File details

Details for the file fswlib-0.1.24-py3-none-any.whl.

File metadata

  • Download URL: fswlib-0.1.24-py3-none-any.whl
  • Upload date:
  • Size: 51.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for fswlib-0.1.24-py3-none-any.whl:

  • SHA256: 72e468aa3765b302f5a0ad79ddee959db950f07dfe28ba36a3d0950e0e2c726b
  • MD5: e424cdda17c6037c4b3c6062cf51387f
  • BLAKE2b-256: e25dd7aca11e7da1325b4ccbd51aecc268dc1b0b282349a46ae13718cf7517ab

