
GPU-accelerated Archetypal Analysis implementation using JAX


ArchetypAX

ArchetypAX - Hardware-accelerated Archetypal Analysis implementation using JAX


Overview

archetypax is a high-performance implementation of Archetypal Analysis (AA) that leverages JAX for GPU acceleration.
Archetypal Analysis is a matrix factorization technique that represents data points
as convex combinations of extreme points (archetypes) found within the data's convex hull.

Unlike traditional dimensionality reduction techniques such as PCA, which finds a basis of orthogonal components,
AA finds interpretable extremal points that often correspond to meaningful prototypes in the data.
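The following plain NumPy sketch makes "convex combinations of archetypes" concrete. It is not the archetypax API. It takes three hypothetical 2D archetypes and solves for the weights of a single point. With exactly three archetypes in 2D, the weights are the point's barycentric coordinates. They are all non-negative and sum to 1 precisely when the point lies inside the triangle the archetypes span. The helper name `convex_weights` is illustrative only.

```python
import numpy as np

# Three hypothetical archetypes (one per row) in 2D; they span a triangle.
Z = np.array([[0.0, 0.0],
              [1.0, 0.0],
              [0.5, 0.866]])

def convex_weights(x, Z):
    """Solve for w with w @ Z == x and sum(w) == 1 (barycentric coordinates).

    Only valid when Z has exactly d + 1 affinely independent rows (here 3 rows in 2D).
    """
    # Stack the coordinate equations Z.T @ w = x with the sum-to-one row 1^T w = 1.
    A = np.vstack([Z.T, np.ones(len(Z))])
    b = np.append(x, 1.0)
    return np.linalg.solve(A, b)

x_inside = np.array([0.5, 0.3])
w = convex_weights(x_inside, Z)
print(w, w.sum())                    # non-negative weights that sum to 1
print(np.allclose(w @ Z, x_inside))  # the point is recovered exactly

x_outside = np.array([1.5, 0.5])
print(convex_weights(x_outside, Z))  # a negative entry: not a convex combination
```

When there are more than d + 1 archetypes, the weights are no longer unique. A point outside the hull also has no valid weights. This is why Archetypal Analysis finds the weights by constrained least squares rather than by a linear solve.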

Features

  • 🚀 GPU/TPU Acceleration: Utilizes JAX for fast computation on GPUs and TPUs
  • 🔍 Interpretable Results: Finds meaningful archetypes that represent extremal patterns in data
  • 🧠 Smart Initialization: Uses k-means++ style initialization for better convergence
  • 🛠️ Numerical Stability: Uses gradient clipping and soft archetype projection to improve numerical stability
  • 📊 scikit-learn Compatible API: Implements the familiar fit/transform interface
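The "Smart Initialization" feature can be pictured with the standard k-means++ seeding rule. The first center is picked uniformly at random. Each further center is picked with probability proportional to its squared distance from the nearest center chosen so far. The NumPy sketch below implements that generic rule. The function name `kmeanspp_init` is hypothetical, and archetypax's own "k-means++ style" initializer may differ in its details.

```python
import numpy as np

def kmeanspp_init(X, k, seed=42):
    """Pick k rows of X as initial archetypes using k-means++ seeding (generic sketch)."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    # First center: a uniformly random data point.
    indices = [int(rng.integers(n))]
    # Squared distance from every point to its nearest chosen center.
    d2 = np.sum((X - X[indices[0]]) ** 2, axis=1)
    for _ in range(1, k):
        # Sample the next center with probability proportional to d2;
        # already-chosen points have d2 == 0, so they are not picked again.
        probs = d2 / d2.sum()
        idx = int(rng.choice(n, p=probs))
        indices.append(idx)
        d2 = np.minimum(d2, np.sum((X - X[idx]) ** 2, axis=1))
    return X[indices], indices

rng = np.random.default_rng(0)
X = rng.random((200, 10))
Z0, idx = kmeanspp_init(X, k=5)
print(Z0.shape, idx)
```

Spreading the initial points toward the edges of the data tends to place them closer to the extremal points that Archetypal Analysis is looking for than a purely uniform draw would.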

Installation

Using pip

pip install archetypax
# or from GitHub
pip install git+https://github.com/lv416e/archetypax.git

Using uv

uv pip install archetypax
# or from GitHub
uv pip install git+https://github.com/lv416e/archetypax.git

Using Poetry

poetry add archetypax
# or from GitHub
poetry add git+https://github.com/lv416e/archetypax.git

Requirements

  • Python 3.10+
  • JAX
  • NumPy
  • scikit-learn

Quick Start

import numpy as np
from archetypax import ArchetypalAnalysis

# Generate sample data
np.random.seed(42)
X = np.random.rand(1000, 10)

# Initialize and fit the model
model = ArchetypalAnalysis(n_archetypes=5)
weights = model.fit_transform(X)

# Get the archetypes
archetypes = model.archetypes

# Reconstruct the data
X_reconstructed = model.reconstruct()

# Calculate reconstruction error
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")

Import Patterns

ArchetypAX supports multiple import patterns for flexibility:

Direct Class Imports (Recommended)

from archetypax import ArchetypalAnalysis, ImprovedArchetypalAnalysis, BiarchetypalAnalysis

Explicit Module Imports

from archetypax.models.base import ArchetypalAnalysis
from archetypax.models.biarchetypes import BiarchetypalAnalysis
from archetypax.tools.evaluation import ArchetypalAnalysisEvaluator

Module-Level Imports

from archetypax.models import ArchetypalAnalysis
from archetypax.tools import ArchetypalAnalysisVisualizer

Documentation

Parameters

ArchetypalAnalysis / ImprovedArchetypalAnalysis

  • n_archetypes: Number of archetypes to find
  • max_iter: Maximum number of iterations (default: 500)
  • tol: Convergence tolerance (default: 1e-6)
  • random_seed: Random seed for initialization (default: 42)
  • learning_rate: Learning rate for optimizer (default: 0.001)

BiarchetypalAnalysis

  • n_archetypes_first: Number of archetypes in the first set
  • n_archetypes_second: Number of archetypes in the second set
  • mixture_weight: Weight for mixing the two archetype sets (0-1) (default: 0.5)
  • max_iter: Maximum number of iterations (default: 500)
  • tol: Convergence tolerance (default: 1e-6)
  • random_seed: Random seed for initialization (default: 42)
  • learning_rate: Learning rate for optimizer (default: 0.001)

Methods

  • fit(X): Fit the model to the data
  • transform(X): Transform new data to archetype weights
  • fit_transform(X): Fit the model and transform the data
  • reconstruct(X): Reconstruct data from archetype weights
  • get_loss_history(): Get the loss history from training
  • get_all_archetypes(): Get both sets of archetypes (BiarchetypalAnalysis only)
  • get_all_weights(): Get both sets of weights (BiarchetypalAnalysis only)

Examples

Visualizing Archetypes in 2D Data

import numpy as np
import matplotlib.pyplot as plt
from archetypax import ImprovedArchetypalAnalysis as ArchetypalAnalysis
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

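# Sample 500 points uniformly inside an equilateral triangle: Dirichlet(1, 1, 1)
# weights are uniform on the simplex, and mapping them through the vertices keeps them uniform
np.random.seed(42)
n_samples = 500
vertices = np.array([[0, 0], [1, 0], [0.5, 0.866]])
weights = np.random.dirichlet(np.ones(3), size=n_samples)
X = weights @ vertices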

# Fit archetypal analysis with 3 archetypes
model = ArchetypalAnalysis(n_archetypes=3)
model.fit(X)

# Plot original data and archetypes
plt.figure(figsize=(10, 8))
ArchetypalAnalysisVisualizer.plot_archetypes_2d(model, X)
plt.title("Archetypal Analysis of 2D Data")
plt.show()

Using Biarchetypal Analysis

import numpy as np
import matplotlib.pyplot as plt
from archetypax import BiarchetypalAnalysis
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(500, 5)

# Initialize and fit the model with two sets of archetypes
model = BiarchetypalAnalysis(
    n_archetypes_first=2,   # Number of archetypes in the first set
    n_archetypes_second=2,  # Number of archetypes in the second set
    mixture_weight=0.5,     # Weight for mixing the two archetype sets (0-1)
    max_iter=500,
    random_seed=42
)
model.fit(X)

# Get both sets of archetypes
positive_archetypes, negative_archetypes = model.get_all_archetypes()
print("Positive archetypes shape:", positive_archetypes.shape)
print("Negative archetypes shape:", negative_archetypes.shape)

# Get both sets of weights
positive_weights, negative_weights = model.get_all_weights()
print("Positive weights shape:", positive_weights.shape)
print("Negative weights shape:", negative_weights.shape)

# Reconstruct data using both sets of archetypes
X_reconstructed = model.reconstruct()
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")

How It Works

Archetypal Analysis solves the following optimization problem:

Given a data matrix $\mathbf{X} \in \mathbb{R}^{n \times d}$ with $n$ samples and $d$ features, find $k$ archetypes $\mathbf{Z} \in \mathbb{R}^{k \times d}$ and weights $\mathbf{w} \in \mathbb{R}^{n \times k}$ such that:

$$ \min_{\mathbf{w},\,\mathbf{Z}} \; \| \mathbf{X} - \mathbf{w}\mathbf{Z} \|_{\text{F}}^2 $$

subject to:

  • $\mathbf{w}$ is non-negative
  • Each row of $\mathbf{w}$ sums to 1 (simplex constraint)
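  • $\mathbf{Z}$ lies within the convex hull of $\mathbf{X}$, i.e. $\mathbf{Z} = \mathbf{B}\mathbf{X}$ for some non-negative $\mathbf{B} \in \mathbb{R}^{k \times n}$ whose rows each sum to 1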
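The convex-hull constraint is usually made explicit by writing the archetypes as Z = B X. Here B is a k × n matrix whose rows are non-negative and sum to 1, so each archetype is itself a convex combination of data points. The sketch below evaluates the objective under that standard parametrization and checks both simplex constraints on random inputs. `aa_objective` and `on_simplex_rows` are hypothetical helpers. Whether archetypax parametrizes Z this way internally is an assumption.

```python
import numpy as np

def on_simplex_rows(M, atol=1e-8):
    """True if every row of M is non-negative and sums to 1."""
    return bool(np.all(M >= -atol) and np.allclose(M.sum(axis=1), 1.0, atol=atol))

def aa_objective(X, w, B):
    """Squared Frobenius reconstruction error ||X - w @ Z||_F^2 with Z = B @ X."""
    Z = B @ X                      # archetypes as convex combinations of data points
    residual = X - w @ Z
    return float(np.sum(residual ** 2))

rng = np.random.default_rng(0)
n, d, k = 100, 4, 3
X = rng.random((n, d))
# Dirichlet samples are non-negative and sum to 1, so these are valid row-stochastic matrices.
w = rng.dirichlet(np.ones(k), size=n)   # shape (n, k)
B = rng.dirichlet(np.ones(n), size=k)   # shape (k, n)

print(on_simplex_rows(w), on_simplex_rows(B))
print(aa_objective(X, w, B))
```

Because both w and B are row-stochastic, every reconstruction w @ B @ X is a convex combination of data points. The fitted model therefore never leaves the convex hull of X.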

This implementation uses JAX's automatic differentiation and optimization tools to efficiently solve this problem on GPUs. It also incorporates several enhancements:

  1. k-means++ style initialization for better initial archetype positions
  2. Entropy regularization to promote more uniform weight distributions
  3. Soft archetype projection using k-nearest neighbors for improved stability
  4. Gradient clipping to prevent numerical issues
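Enhancements 2 and 4, together with the simplex constraint on the weights, can be illustrated with projected gradient descent on w while the archetypes Z are held fixed. A full fit would also update Z. The NumPy sketch below is an assumption-laden illustration, not archetypax's code. It adds an entropy term weighted by a hypothetical `entropy_reg` coefficient and clips the gradient to a hypothetical global norm `max_norm`. It then takes a step of size `learning_rate` and restores the constraint with the sort-based Euclidean projection onto the probability simplex. archetypax may enforce the constraint differently, for example with a softmax. The loop stops after `max_iter` steps or when the loss changes by less than `tol`, echoing the parameter names listed earlier. The default values here are chosen for the toy example and are not taken from archetypax.

```python
import numpy as np

def project_rows_to_simplex(V):
    """Euclidean projection of each row of V onto the probability simplex (sort-based)."""
    n, k = V.shape
    U = -np.sort(-V, axis=1)                     # rows sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0
    ind = np.arange(1, k + 1)
    cond = U - css / ind > 0                     # True on a prefix of each row
    rho = k - np.argmax(cond[:, ::-1], axis=1)   # length of that prefix
    theta = css[np.arange(n), rho - 1] / rho
    return np.maximum(V - theta[:, None], 0.0)

def loss_and_grad(w, X, Z, entropy_reg, eps=1e-12):
    """Reconstruction loss plus entropy_reg * sum(w log w), and its gradient in w.

    Minimizing sum(w log w) maximizes entropy, which pushes weights toward uniform.
    """
    residual = X - w @ Z
    neg_entropy = np.sum(w * np.log(w + eps))
    loss = np.sum(residual ** 2) + entropy_reg * neg_entropy
    grad = -2.0 * residual @ Z.T + entropy_reg * (np.log(w + eps) + w / (w + eps))
    return float(loss), grad

def fit_weights(X, Z, learning_rate=0.01, max_iter=500, tol=1e-6,
                entropy_reg=1e-3, max_norm=1.0, seed=0):
    """Projected gradient descent on w with gradient clipping; Z stays fixed."""
    rng = np.random.default_rng(seed)
    w = rng.dirichlet(np.ones(Z.shape[0]), size=X.shape[0])
    history, prev = [], np.inf
    for _ in range(max_iter):
        loss, grad = loss_and_grad(w, X, Z, entropy_reg)
        history.append(loss)
        if abs(prev - loss) < tol:               # converged: loss barely changed
            break
        prev = loss
        norm = np.linalg.norm(grad)              # global (Frobenius) norm
        if norm > max_norm:                      # gradient clipping
            grad = grad * (max_norm / norm)
        w = project_rows_to_simplex(w - learning_rate * grad)
    return w, history

# Toy data generated from known archetypes (the triangle from the 2D example).
Z = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.866]])
rng = np.random.default_rng(1)
X = rng.dirichlet(np.ones(3), size=200) @ Z
w, history = fit_weights(X, Z)
print(len(history), history[0], history[-1])
```

Clipping bounds the size of each update regardless of how large the raw gradient is. The projection then guarantees that every intermediate w is a valid set of convex weights.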

Citation

If you use this package in your research, please cite:

@software{archetypax2025,
  author = {mary},
  title = {archetypax: GPU-accelerated Archetypal Analysis using JAX},
  year = {2025},
  url = {https://github.com/lv416e/archetypax}
}

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Download files

Source Distribution

archetypax-0.1.0.dev2.tar.gz (49.4 kB)

Built Distribution

archetypax-0.1.0.dev2-py3-none-any.whl (46.2 kB)

File details

Details for the file archetypax-0.1.0.dev2.tar.gz.

File metadata

  • Download URL: archetypax-0.1.0.dev2.tar.gz
  • Size: 49.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for archetypax-0.1.0.dev2.tar.gz
Algorithm Hash digest
SHA256 2b68fcbc15958e1d94c66a60c4e203f19d9e39b3e68f4dc6c8ee51baeba189fc
MD5 3b16bbc465189bc970d8c938b606db68
BLAKE2b-256 8c2d7f003fce3c6dd0e801eae9ed01b213b7f0fe16a08b11a11b230a087f5ab4

Provenance

The following attestation bundles were made for archetypax-0.1.0.dev2.tar.gz:

Publisher: release.yml on lv416e/archetypax

File details

Details for the file archetypax-0.1.0.dev2-py3-none-any.whl.

File hashes

Hashes for archetypax-0.1.0.dev2-py3-none-any.whl
Algorithm Hash digest
SHA256 fc1b531cddbb87a56a9679fd9ab60d7b95cf54a39bd6dbf8113d261d6c2288a7
MD5 0d0cb9ef1df352c770bcec8f03a2b942
BLAKE2b-256 75cfb9caf870ef4ac0d037ce49748d9b4b28059e26b6386b9043440448a861c2

Provenance

The following attestation bundles were made for archetypax-0.1.0.dev2-py3-none-any.whl:

Publisher: release.yml on lv416e/archetypax
