Fourier Sliced-Wasserstein (FSW) embedding — a PyTorch-based library
This package provides an implementation of the Fourier Sliced-Wasserstein (FSW) embedding for multisets and measures, introduced in our ICLR 2025 paper:
Fourier Sliced-Wasserstein Embedding for Multisets and Measures
Tal Amir, Nadav Dym
International Conference on Learning Representations (ICLR), 2025
📦 Requirements
- Python ≥ 3.10.3 (released March 2022)
- PyTorch ≥ 2.1.0 (released October 2023)
- NumPy ≥ 1.24.4 (released June 2023)
The core package has been tested on Linux and Windows.
It may also run on macOS (CPU only), though this has not been verified.
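If you want to confirm that your environment meets these minimums before installing, a quick check along the following lines can help (this snippet is independent of fswlib and only reads version strings):

```python
# Quick environment check (not part of fswlib): prints the interpreter,
# PyTorch, and NumPy versions so they can be compared against the minimums above.
import sys

import numpy
import torch

print("Python :", sys.version.split()[0])        # needs >= 3.10.3
print("PyTorch:", torch.__version__)             # needs >= 2.1.0
print("NumPy  :", numpy.__version__)             # needs >= 1.24.4
print("CUDA available:", torch.cuda.is_available())
```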
🔧 Installation
To install the package:
```bash
pip install fswlib
```
The core package runs on both CPU and CUDA-enabled GPUs, using PyTorch's standard CUDA backend.
In addition, it includes an optional custom CUDA extension that can provide up to 2× speedup for sparse weight matrices (e.g., sparse graphs). This extension is currently supported only on Linux.
To compile the optional extension, run:
```bash
fswlib-build
```
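As a rough illustration of the kind of workload this extension targets, the hypothetical sketch below builds a batch of multisets whose weight vectors are mostly zero (as when each weight vector selects the neighborhood of one node in a sparse graph) and embeds them with the standard API shown in the usage example below. The sizes, the sparsity level, and the assumption that zero weights are what the extension exploits are illustrative; the sketch does not check whether the compiled kernel is actually in use.

```python
# Hypothetical sketch (not from the library's docs): a batch of multisets whose
# weight vectors are mostly zero, embedded with the standard FSWEmbedding call.
import torch
from fswlib import FSWEmbedding

device = 'cuda' if torch.cuda.is_available() else 'cpu'
d, n, m = 15, 1000, 123      # element dimension, multiset size, embedding dimension
num_nodes = 256              # e.g., one multiset (neighborhood) per graph node

embed = FSWEmbedding(d_in=d, d_out=m, device=device, dtype=torch.float32)

X = torch.randn(num_nodes, n, d, device=device)            # batched elements
W = torch.rand(num_nodes, n, device=device)                # batched weights
W = W * (torch.rand(num_nodes, n, device=device) < 0.05)   # zero out ~95% of the weights

X_emb = embed(X, W)          # shape: (num_nodes, m)
print(X_emb.shape)
```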
📘 Usage Example
Below is a basic usage example of the FSWEmbedding class.
For more examples, see the examples/ directory of the GitHub repository.
Full API documentation is available at https://tal-amir.github.io/fswlib.
```python
import torch
from fswlib import FSWEmbedding
# Configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32
d = 15 # Input element dimension
n = 50 # Multiset size
m = 123 # Embedding output dimension
# Create FSW embedding module for multisets/measures over ℝ^d
embed = FSWEmbedding(d_in=d, d_out=m, device=device, dtype=dtype)
# --- Single input multiset ---
X = torch.randn(size=(n, d), device=device, dtype=dtype)
W = torch.rand(n, device=device, dtype=dtype) # Optional weights
X_emb = embed(X, W) # Embeds a weighted multiset
X_emb = embed(X) # Embeds X assuming uniform weights
# --- A batch of input multisets ---
batch_dims = (5, 3, 7, 9)
Xb = torch.randn(size=batch_dims + (n, d), device=device, dtype=dtype)
Wb = torch.rand(batch_dims + (n,), device=device, dtype=dtype)
Xb_emb = embed(Xb, Wb)
print(f"Dimension of multiset elements: {d}")
print(f"Embedding dimension: {m}")
print(f"\nOne multiset X of size {n}:")
print("X shape:", X.shape)
print("embed(X) shape:", X_emb.shape)
batch_dim_str = "×".join(str(b) for b in batch_dims)
print(f"\nBatch of {batch_dim_str} multisets, each of size {n}:")
print("Xb shape:", Xb.shape)
print("embed(Xb) shape:", Xb_emb.shape)
# --- Encoding multiset size (total mass) ---
# By default, the embedding is invariant to the input multiset size, since it treats inputs as *probability measures*.
# Set `encode_total_mass = True` to make the embedding encode the size of the input multisets, or, more generally,
# the total mass (i.e. sum of weights).
embed_total_mass_invariant = FSWEmbedding(d_in=d, d_out=m, device=device, dtype=dtype)
embed_total_mass_aware = FSWEmbedding(d_in=d, d_out=m, encode_total_mass=True, device=device, dtype=dtype)
# Two multisets with identical proportions but different cardinalities
X = torch.rand(3, d, device=device, dtype=dtype)
v1, v2, v3 = X[0], X[1], X[2]
X1 = torch.stack([v1, v2, v3])
X2 = torch.stack([v1, v1, v2, v2, v3, v3])
# Embedding *without* total mass encoding
X1_emb = embed_total_mass_invariant(X1)
X2_emb = embed_total_mass_invariant(X2)
# Embedding *with* total mass encoding
X1_emb_aware = embed_total_mass_aware(X1)
X2_emb_aware = embed_total_mass_aware(X2)
# Measure the differences
diff_invariant = torch.norm(X1_emb - X2_emb).item()
diff_aware = torch.norm(X1_emb_aware - X2_emb_aware).item()
print()
print("Two different-size multisets with identical element proportions:")
print("X₁ = {v₁, v₂, v₃}, X₂ = {v₁, v₁, v₂, v₂, v₃, v₃}")
print("Embedding difference: ‖Embed(X₁) − Embed(X₂)‖₂")
print(f"With total mass encoding: {diff_aware}")
print(f"Without total mass encoding: {diff_invariant:.2e}")
Output:
```
Dimension of multiset elements: 15
Embedding dimension: 123
One multiset X of size 50:
X shape: torch.Size([50, 15])
embed(X) shape: torch.Size([123])
Batch of 5×3×7×9 multisets, each of size 50:
Xb shape: torch.Size([5, 3, 7, 9, 50, 15])
embed(Xb) shape: torch.Size([5, 3, 7, 9, 123])
Two different-size multisets with identical element proportions:
X₁ = {v₁, v₂, v₃}, X₂ = {v₁, v₁, v₂, v₂, v₃, v₃}
Embedding difference: ‖Embed(X₁) − Embed(X₂)‖₂
With total mass encoding: 3.0
Without total mass encoding: 5.09e-07
```
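A natural way to use the embedding is as a multiset-pooling front end for a downstream network. The sketch below is a minimal illustration under the assumption (not stated above) that the embedding is differentiable with respect to its inputs and composes with ordinary torch.nn layers; the classifier head, batch size, and loss are arbitrary placeholders.

```python
# Hypothetical sketch: FSWEmbedding as a multiset-pooling front end for a small
# classifier head. Assumes the embedding is differentiable w.r.t. its inputs;
# the head, batch size, and loss below are placeholders, not part of fswlib.
import torch
import torch.nn as nn
from fswlib import FSWEmbedding

device = 'cuda' if torch.cuda.is_available() else 'cpu'
d, n, m, num_classes = 15, 50, 123, 10

embed = FSWEmbedding(d_in=d, d_out=m, device=device, dtype=torch.float32)
head = nn.Sequential(nn.Linear(m, 64), nn.ReLU(), nn.Linear(64, num_classes)).to(device)

X = torch.randn(32, n, d, device=device, requires_grad=True)   # a batch of 32 multisets
labels = torch.randint(0, num_classes, (32,), device=device)   # dummy targets

logits = head(embed(X))                                        # (32, num_classes)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()                                                # gradients flow back to X
print("dL/dX shape:", X.grad.shape)                            # (32, 50, 15)
```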
📄 Citation
If you use this library in your research, please cite our paper:
```bibtex
@inproceedings{amir2025fsw,
  title     = {Fourier Sliced-{W}asserstein Embedding for Multisets and Measures},
  author    = {Tal Amir and Nadav Dym},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2025}
}
```
🔗 Links
- Paper: ICLR 2025
- Code: GitHub repository
👨🏻‍🔧 Maintainer
This library is maintained by Tal Amir.
Contact: talamir@technion.ac.il
Download files
File details
Details for the file fswlib-0.1.25.tar.gz.
File metadata
- Download URL: fswlib-0.1.25.tar.gz
- Upload date:
- Size: 52.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 61fbe3f5b78e6fb8f2bf5d6ad4704a9c724e1526be6c36a8663392394d7e79ae |
| MD5 | 994c053964c27ffcb71aaec2f1f94112 |
| BLAKE2b-256 | 6f24ebf291d570979a0efd94cdbe793b2bae6da77a740c17308870ee5990a7bf |
File details
Details for the file fswlib-0.1.25-py3-none-any.whl.
File metadata
- Download URL: fswlib-0.1.25-py3-none-any.whl
- Upload date:
- Size: 53.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 578e8a93faff0de9bf46821dada31972845144d3a82a656235d4284342f270af |
| MD5 | cba6f809f36ab8a2e79165b23c164c30 |
| BLAKE2b-256 | c4791fbe8c18521883cfcd03097f23fd07889e15e8e6a59c91572a9957891794 |