
Implementation of the smoothed Weiszfeld algorithm to compute the geometric median

Project description

Differentiable and Fast Geometric Median in NumPy and PyTorch

This package implements a fast numerical algorithm to compute the geometric median of high-dimensional vectors. As a generalization of the median (of scalars), the geometric median is a robust estimator of the mean in the presence of outliers and contamination (adversarial or otherwise).

Definition: the geometric median of points x_1, ..., x_n with non-negative weights w_1, ..., w_n is the point z minimizing the weighted sum of Euclidean distances, sum_i w_i * ||x_i - z||_2.

The geometric median is also known as the Fermat point, the Weber point, the L1 median, or the Fréchet median, among other names. It has a breakdown point of 0.5, meaning that it yields a robust aggregate even under arbitrary corruption of points accounting for under half of the total weight. We use the smoothed Weiszfeld algorithm to compute the geometric median.
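
For intuition only, here is a minimal NumPy sketch of the classical (unsmoothed) Weiszfeld iteration, which repeatedly re-weights each point by the inverse of its distance to the current estimate. This is not the package's smoothed implementation; in particular, the eps guard below is a crude stand-in for proper smoothing, and the package should be treated as the reference.

import numpy as np

def weiszfeld_sketch(points, weights, eps=1e-8, max_iter=100, tol=1e-7):
    # points: (n, d) array; weights: (n,) array of non-negative weights
    z = np.average(points, axis=0, weights=weights)  # start at the weighted mean
    for _ in range(max_iter):
        dist = np.maximum(np.linalg.norm(points - z, axis=1), eps)  # guard against division by zero
        inv = weights / dist
        z_new = (inv[:, None] * points).sum(axis=0) / inv.sum()
        converged = np.linalg.norm(z_new - z) <= tol
        z = z_new
        if converged:
            break
    return z

# Example: z = weiszfeld_sketch(np.random.rand(10, 25), np.ones(10))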

Features:

  • Implementation in both NumPy and PyTorch.
  • PyTorch implementation is fully differentiable (compatible with gradient backpropagation a.k.a. automatic differentiation) and can run on GPUs with CUDA tensors.
  • Blazing fast algorithm that converges linearly in almost all practical settings.

Installation

This package can be installed via pip as pip install geom_median. Alternatively, for an editable install, run

git clone git@github.com:krishnap25/geom_median.git
cd geom_median
pip install -e .

You must have a working installation of PyTorch (version 1.7 or later) if you wish to use the PyTorch API. See the PyTorch installation instructions for details.
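
As a quick sanity check (just a suggestion, not part of the package), you can confirm that a suitable PyTorch version and the package itself import correctly:

import torch
print(torch.__version__)  # expect 1.7.0 or later for the PyTorch API
from geom_median.torch import compute_geometric_median  # should import without error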

Usage Guide

We describe the PyTorch usage here. The NumPy API is entirely analogous.

import torch
from geom_median.torch import compute_geometric_median   # PyTorch API
# from geom_median.numpy import compute_geometric_median  # NumPy API

For the simplest use case, supply a list of tensors:

n = 10  # Number of vectors
d = 25  # dimensionality of each vector
points = [torch.rand(d) for _ in range(n)]   # list of n tensors of shape (d,)
# All tensors must have the same shape; the shape itself can be arbitrary (not necessarily 1-dimensional)
weights = torch.rand(n)  # non-negative weights of shape (n,)
out = compute_geometric_median(points, weights)
# Access the median via `out.median`, which has the same shape as the points, i.e., (d,)

The termination condition can be examined through out.termination, which gives a message such as "function value converged within tolerance" or "maximum iterations reached".
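
To see the robustness described above in action, here is a small illustrative demo (the exact numbers will vary from run to run): a single large outlier drags the mean far from the bulk of the points but barely moves the geometric median.

points = [torch.randn(d) for _ in range(n)]
points[0] = points[0] + 1000.0  # corrupt one point with a large outlier
out = compute_geometric_median(points, weights=None)
print(out.termination)  # e.g., "function value converged within tolerance"
print(torch.linalg.norm(torch.stack(points).mean(dim=0)))  # the mean is dragged far from the origin
print(torch.linalg.norm(out.median))  # the geometric median stays close to the bulk of the points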

We also support a use case where each point is given as a list of tensors: for instance, each point could be the list of parameters of a torch.nn.Module, obtained as point = list(module.parameters()). In this case, the result is equivalent to flattening and concatenating all the tensors of each point into a single vector via flattened_point = torch.cat([v.view(-1) for v in point]). This functionality can be invoked as follows:

models = [torch.nn.Linear(20, 10) for _ in range(n)]  # a list of n models
points = [list(model.parameters()) for model in models]  # list of points, where each point is a list of tensors
out = compute_geometric_median(points, weights=None)  # equivalent to `weights = torch.ones(n)`. 
# Access the median via `out.median`, also given as a list of tensors
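
As an informal check of the flattening equivalence described above (illustrative only, not part of the documented API), the list-of-tensors result should agree with the median of the flattened vectors up to numerical tolerance:

flattened_points = [torch.cat([v.view(-1) for v in point]) for point in points]
out_flat = compute_geometric_median(flattened_points, weights=None)
flattened_median = torch.cat([v.view(-1) for v in out.median])
print((flattened_median - out_flat.median).abs().max())  # expected to be small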

We also support computing the geometric median for each component separately in the list-of-tensors format:

models = [torch.nn.Linear(20, 10) for _ in range(n)]  # a list of n models
points = [list(model.parameters()) for model in models]  # list of points, where each point is a list of tensors
out = compute_geometric_median(points, weights=None, per_component=True)  
# Access the median via `out.median`, also given as a list of tensors

This per-component geometric median is equivalent in functionality to computing, for each component index j,

out.median[j] = compute_geometric_median([p[j] for p in points], weights).median
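
A small illustrative check of this per-component equivalence (not part of the documented API; "small" here is a loose expectation, not a guaranteed tolerance):

for j in range(len(out.median)):
    out_j = compute_geometric_median([p[j] for p in points], weights=None)
    print((out.median[j] - out_j.median).abs().max())  # expected to be small for every component j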

Backpropagation support

When using the PyTorch API, the result out.median, as a function of points, supports gradient backpropagation, also known as reverse-mode automatic differentiation. Here is a toy example illustrating this behavior.

points = [torch.rand(d).requires_grad_(True) for _ in range(n)]   # list of tensors with `requires_grad=True`
out = compute_geometric_median(points, weights=None)
torch.linalg.norm(out.median).backward()  # call backward on any downstream function of `out.median`
gradients = [p.grad for p in points]  # gradients with respect to `points` and upstream nodes in the computation graph

GPU support

Simply use the API as above, with points and weights given as CUDA tensors.
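
For instance (a brief sketch assuming a CUDA-capable GPU is available):

device = torch.device("cuda")  # assumes a CUDA-capable GPU is available
points = [torch.rand(d, device=device) for _ in range(n)]
weights = torch.rand(n, device=device)
out = compute_geometric_median(points, weights)
print(out.median.device)  # the result is expected to live on the same device as the inputs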

Authors and Contact

Krishna Pillutla
Sham Kakade
Zaid Harchaoui

In case of questions, please raise an issue on GitHub.

Citation

If you find this package useful, please consider citing the following paper:

@article{pillutla:etal:rfa,
  title={{Robust Aggregation for Federated Learning}},
  author={Pillutla, Krishna and Kakade, Sham M. and Harchaoui, Zaid},
  journal={arXiv preprint},
  year={2019}
}
