GPU-accelerated Archetypal Analysis implementation using JAX

ArchetypAX – Hardware-accelerated Archetypal Analysis with JAX

Discover extreme patterns in your data with GPU/TPU-accelerated Archetypal Analysis, high-performance convex hull optimization, and interpretable matrix factorization.

Overview

archetypax is a high-performance implementation of Archetypal Analysis (AA) leveraging JAX for GPU acceleration.

Archetypal Analysis is a powerful matrix factorization technique representing data points
as convex combinations of extreme points (archetypes) found within the data's convex hull.

Unlike traditional dimensionality reduction techniques like PCA, which finds abstract orthogonal components,
AA discovers interpretable extremal points often corresponding to meaningful prototypes.

This makes it valuable for applications requiring both dimensionality reduction and human-interpretable insights,
such as market segmentation, document analysis, and anomaly detection.
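
To make "convex combination" concrete, the following minimal NumPy sketch builds one data point from three hand-picked archetypes. The array names and values are illustrative assumptions and are not part of the archetypax API.

```python
import numpy as np

# Three hypothetical archetypes in 2D (one archetype per row).
archetypes = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 2.0]])

# Convex weights: non-negative and summing to 1.
weights = np.array([0.5, 0.25, 0.25])

# The point is a weighted average of the archetypes,
# so it lies inside the triangle they span.
point = weights @ archetypes
print(point)  # the point (1.0, 0.5)
```

Reading the weights back gives the interpretation: this point is half the first archetype and a quarter each of the other two.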

Features

Performance & Stability

  • 🚀 GPU/TPU acceleration using JAX
  • 🧠 Smart initialization (k-means++, directional)
  • 🛠️ Numerical stability & convergence techniques

Usability & Compatibility

  • 📊 scikit-learn compatible API (fit/transform)
  • 📋 Thorough documentation

Interpretability & Visualization

  • 🔍 Meaningful interpretable archetypes
  • 📈 Advanced tracking & optimization trajectory monitoring
  • 🎯 Comprehensive evaluation & visualization tooling

Related Projects and Techniques

ArchetypAX can be used alongside or compared with these related approaches:

  • PCA: Principal Component Analysis finds orthogonal directions of maximum variance
  • NMF: Non-negative Matrix Factorization decomposes data into non-negative components
  • k-means: Clustering technique that partitions data into k clusters
  • JAX Ecosystem: Compatible with JAX-based machine learning frameworks like Flax
  • scikit-learn: Follows similar API conventions, allowing easy integration

Installation

Install with pip, uv, or poetry:

# pip
pip install archetypax
pip install git+https://github.com/lv416e/archetypax.git

# uv
uv pip install archetypax
uv pip install git+https://github.com/lv416e/archetypax.git

# poetry
poetry add archetypax
poetry add git+https://github.com/lv416e/archetypax.git

Install optional dependencies:

pip install "archetypax[dev]"       # Development dependencies
pip install "archetypax[examples]"  # Example dependencies
pip install "archetypax[docs]"      # Documentation dependencies

Requirements

| Type | Dependency | Version | Description |
|------|------------|---------|-------------|
| Core | Python | >=3.10 | Required for modern language features and compatibility with JAX |
| Core | JAX | >=0.4.0 | Powers the hardware acceleration and automatic differentiation |
| Core | NumPy | >=1.20.0 | Handles core numerical operations and array manipulations |
| Core | optax | >=0.1.0 | JAX-based optimization framework for gradient-based updates |
| Core | pandas | >=1.3.0 | Data manipulation and analysis library |
| Core | scikit-learn | >=1.0.0 | Provides machine learning utilities and compatible interfaces |
| Examples | jupyter | >=1.0.0 | Interactive computing environment for notebooks |
| Examples | matplotlib | >=3.7.5 | Required for visualization functionality |
| Examples | seaborn | >=0.13.2 | Statistical data visualization |
| Dev | black | ==23.7.0 | Code formatter |
| Dev | mypy | >=1.8.0 | Static type checker |
| Dev | pytest | >=7.0.0 | Testing framework |
| Dev | ruff | >=0.9.0 | Fast Python linter and formatter |

Quick Start

import numpy as np
from archetypax import ImprovedArchetypalAnalysis as ArchetypalAnalysis

# Generate sample data
np.random.seed(42)
X = np.random.rand(1000, 10)

# Initialize and fit the model
model = ArchetypalAnalysis(n_archetypes=5)
weights = model.fit_transform(X)

# Get the archetypes
archetypes = model.archetypes

# Reconstruct the data
X_reconstructed = model.reconstruct()

# Calculate reconstruction error
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")
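
A quick sanity check on the returned weights can catch shape or constraint problems early. The helper below is a hypothetical utility, not part of archetypax. It assumes that `fit_transform` returns an array of shape (n_samples, n_archetypes) whose rows satisfy the simplex constraint described under How It Works, and it is demonstrated on synthetic Dirichlet weights so that it runs without the library installed.

```python
import numpy as np

def check_simplex_rows(weights, atol=1e-5):
    """Return True if every row is non-negative and sums to 1 (within atol)."""
    weights = np.asarray(weights)
    non_negative = np.all(weights >= -atol)
    rows_sum_to_one = np.allclose(weights.sum(axis=1), 1.0, atol=atol)
    return bool(non_negative and rows_sum_to_one)

# Stand-in for `weights = model.fit_transform(X)`.
rng = np.random.default_rng(42)
demo_weights = rng.dirichlet(np.ones(5), size=1000)

print(check_simplex_rows(demo_weights))        # valid simplex rows
print(check_simplex_rows(demo_weights * 2.0))  # rows now sum to 2
```

The tolerance is deliberately loose because accelerated backends commonly compute in 32-bit floats.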

Import Patterns

ArchetypAX supports multiple import patterns for flexibility:

Direct Class Imports (Recommended)

from archetypax import ArchetypalAnalysis, ImprovedArchetypalAnalysis, BiarchetypalAnalysis, ArchetypeTracker

Explicit Module Imports

from archetypax.models.base import ArchetypalAnalysis
from archetypax.models.biarchetypes import BiarchetypalAnalysis
from archetypax.tools.evaluation import ArchetypalAnalysisEvaluator
from archetypax.tools.tracker import ArchetypeTracker

Module-Level Imports

from archetypax.models import ArchetypalAnalysis
from archetypax.tools import ArchetypalAnalysisVisualizer, ArchetypeTracker

Changelog

For a detailed list of changes and version history, please see the CHANGELOG.md file.

Documentation

Parameters

ArchetypalAnalysis / ImprovedArchetypalAnalysis

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| n_archetypes | int | - | Number of archetypes to find |
| max_iter | int | 500 | Maximum number of iterations |
| tol | float | 1e-6 | Convergence tolerance |
| random_seed | int | 42 | Random seed for initialization |
| learning_rate | float | 0.001 | Learning rate for optimizer |
| lambda_reg | float | 0.01 | Regularization strength for weight distribution |
| normalize | bool | False | Whether to normalize features before fitting |
| projection_method | str | "cbap" | Method for projecting archetypes ("cbap", "convex_hull", "knn") |
| projection_alpha | float | 0.1 | Blending coefficient for boundary projection |
| archetype_init_method | str | "directional" | Initialization strategy ("directional", "kmeans++", "qhull") |

BiarchetypalAnalysis

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| n_row_archetypes | int | - | Number of archetypes in observation space |
| n_col_archetypes | int | - | Number of archetypes in feature space |
| max_iter | int | 500 | Maximum number of iterations |
| tol | float | 1e-6 | Convergence tolerance |
| random_seed | int | 42 | Random seed for initialization |
| learning_rate | float | 0.001 | Learning rate for optimizer |
| projection_method | str | "default" | Method for projecting archetypes |
| lambda_reg | float | 0.01 | Regularization strength for entropy terms |
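
The `lambda_reg` description mentions entropy terms. As a rough illustration of what an entropy measure on weights captures, the sketch below computes the Shannon entropy of each weight row: a row concentrated on one archetype scores near 0, while a uniform row over k archetypes scores log(k). How archetypax combines such a term with the reconstruction loss, including its sign, is not stated here, so `row_entropy` is a hypothetical helper and not the library's regularizer.

```python
import numpy as np

def row_entropy(weights, eps=1e-12):
    """Shannon entropy (in nats) of each row of a row-stochastic matrix."""
    # Clip to avoid log(0) for weights that are exactly zero.
    w = np.clip(np.asarray(weights, dtype=float), eps, 1.0)
    return -np.sum(w * np.log(w), axis=1)

weights = np.array([
    [1.0, 0.0, 0.0],        # fully assigned to one archetype
    [1 / 3, 1 / 3, 1 / 3],  # spread evenly over three archetypes
])
entropy = row_entropy(weights)
print(entropy)  # first entry is close to 0, second equals log(3)
```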

Methods

| Method | Returns | Description |
|--------|---------|-------------|
| fit(X) | model | Fit the model to the data |
| transform(X) | array | Transform new data to archetype weights |
| fit_transform(X) | array | Fit the model and transform the data |
| reconstruct(X) | array | Reconstruct data from archetype weights |
| get_loss_history() | array | Get the loss history from training |
| get_all_archetypes() | tuple | Get both sets of archetypes (BiarchetypalAnalysis only) |
| get_all_weights() | tuple | Get both sets of weights (BiarchetypalAnalysis only) |

Examples

Visualizing Archetypes in 2D Data

import numpy as np
import matplotlib.pyplot as plt
from archetypax import ImprovedArchetypalAnalysis
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Generate some interesting 2D data (a triangle with points inside)
np.random.seed(42)  # fixed seed so the example is reproducible
n_samples = 500
vertices = np.array([[0, 0], [1, 0], [0.5, 0.866]])
weights = np.random.dirichlet(np.ones(3), size=n_samples)
X = weights @ vertices

# Fit archetypal analysis with 3 archetypes
model = ImprovedArchetypalAnalysis(n_archetypes=3, archetype_init_method="directional")
model.fit(X)

# Plot original data and archetypes
plt.figure(figsize=(10, 8))
ArchetypalAnalysisVisualizer.plot_archetypes_2d(model, X)
plt.title("Archetypal Analysis of 2D Data")
plt.show()
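
Because the triangle data are generated from known vertices, recovered archetypes can be compared against them numerically as well as visually. The sketch below shows one way to do that with plain NumPy. Here `recovered` holds made-up values standing in for `model.archetypes`, and `match_to_vertices` is a hypothetical helper, not a library function.

```python
import numpy as np

def match_to_vertices(recovered, vertices):
    """For each true vertex, find the nearest recovered archetype.

    Returns the index of that archetype and its Euclidean distance.
    """
    # Pairwise distances, shape (n_vertices, n_recovered).
    diffs = vertices[:, None, :] - recovered[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    nearest = dists.argmin(axis=1)
    return nearest, dists[np.arange(len(vertices)), nearest]

vertices = np.array([[0, 0], [1, 0], [0.5, 0.866]])
# Illustrative values only: archetypes slightly inside the triangle, in shuffled order.
recovered = np.array([[0.95, 0.03], [0.5, 0.80], [0.04, 0.02]])

nearest, distances = match_to_vertices(recovered, vertices)
print(nearest)    # index of the closest recovered archetype per vertex
print(distances)  # small distances indicate the corners were recovered
```

Nearest-neighbour matching is not guaranteed to be one-to-one; if two vertices map to the same archetype, that is itself a sign that a corner was missed.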

Using Biarchetypal Analysis

import numpy as np
from archetypax import BiarchetypalAnalysis

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(500, 5)

# Initialize and fit the model with row and column archetypes
model = BiarchetypalAnalysis(
    n_row_archetypes=2,   # Number of archetypes in observation space
    n_col_archetypes=2,   # Number of archetypes in feature space
    max_iter=500,
    random_seed=42
)
model.fit(X)

# Get both sets of archetypes
row_archetypes, col_archetypes = model.get_all_archetypes()
print("Row archetypes shape:", row_archetypes.shape)
print("Column archetypes shape:", col_archetypes.shape)

# Get both sets of weights
row_weights, col_weights = model.get_all_weights()
print("Row weights shape:", row_weights.shape)
print("Column weights shape:", col_weights.shape)

# Reconstruct data using biarchetypes
X_reconstructed = model.reconstruct()
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")

Tracking Archetype Evolution

import numpy as np
from archetypax import ArchetypeTracker

# Generate sample data
np.random.seed(42)
X = np.random.rand(1000, 10)

# Initialize the tracker
tracker = ArchetypeTracker(
    n_archetypes=3,
    max_iter=300,
    random_seed=42
)

# Fit the model while tracking archetype movement
tracker.fit(X)

# Visualize the archetype movement trajectory
tracker.visualize_movement()

# Visualize boundary proximity over iterations
tracker.visualize_boundary_proximity()

How It Works

Archetypal Analysis solves the following optimization problem:

Given a data matrix $\mathbf{X} \in \mathbb{R}^{n \times d}$ with $n$ samples and $d$ features, find $k$ archetypes $\mathbf{A} \in \mathbb{R}^{k \times d}$ and weights $\mathbf{W} \in \mathbb{R}^{n \times k}$ such that:

$$ \text{minimize} \ \| \mathbf{X} - \mathbf{W} \mathbf{A} \|^2_{\text{F}} $$

subject to:

  • $\mathbf{W}$ is non-negative
  • Each row of $\mathbf{W}$ sums to 1 (simplex constraint)
  • Each row of $\mathbf{A}$ (each archetype) lies within the convex hull of the rows of $\mathbf{X}$
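
The objective and constraints can be written out in a few lines of NumPy. This is an illustrative sketch, not the archetypax implementation. It uses one common way to satisfy the hull constraint, writing A = B X with a row-stochastic B, and the names `B` and `aa_loss` are assumptions introduced here.

```python
import numpy as np

def aa_loss(X, W, A):
    """Squared Frobenius norm of the reconstruction residual X - W A."""
    return float(np.sum((X - W @ A) ** 2))

rng = np.random.default_rng(0)
n, d, k = 200, 4, 3
X = rng.random((n, d))

# Rows drawn from a Dirichlet distribution are non-negative and sum to 1,
# so both matrices satisfy the simplex constraint by construction.
B = rng.dirichlet(np.ones(n), size=k)  # (k, n): archetypes as mixtures of data points
W = rng.dirichlet(np.ones(k), size=n)  # (n, k): data points as mixtures of archetypes

A = B @ X                              # (k, d): inside the convex hull of X
print(A.shape, aa_loss(X, W, A))
```

The random factors here are feasible but not optimal; fitting means adjusting them to drive the loss down while staying feasible.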

The biarchetypal extension solves a more complex factorization:

$$ \mathbf{X} \approx \mathbf{\alpha} \cdot \mathbf{\beta} \cdot \mathbf{X} \cdot \mathbf{\theta} \cdot \mathbf{\gamma} $$
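
The shapes in this factorization are easier to follow in code. The sketch below assumes the usual formulation of biarchetypal analysis, in which the row-side factors are row-stochastic and the column-side factors are column-stochastic. The internal layout used by archetypax may differ, and all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 50, 6         # samples, features
k_row, k_col = 2, 3  # row and column archetype counts

X = rng.random((n, d))

# Row side: each row sums to 1.
alpha = rng.dirichlet(np.ones(k_row), size=n)    # (n, k_row)
beta = rng.dirichlet(np.ones(n), size=k_row)     # (k_row, n)

# Column side: each column sums to 1 (transposed Dirichlet draws).
theta = rng.dirichlet(np.ones(d), size=k_col).T  # (d, k_col)
gamma = rng.dirichlet(np.ones(k_col), size=d).T  # (k_col, d)

biarchetypes = beta @ X @ theta                  # (k_row, k_col)
X_hat = alpha @ biarchetypes @ gamma             # (n, d)
print(biarchetypes.shape, X_hat.shape)
```

The small `biarchetypes` matrix summarizes the data along both rows and columns at once, which is what distinguishes this model from ordinary Archetypal Analysis.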

This implementation uses JAX's automatic differentiation and optimization tools to solve these problems efficiently on GPUs and TPUs. It also incorporates several enhancements:

  1. Strategic initialization methods including directional initialization, k-means++ style, and convex hull approximation
  2. Intelligent regularization techniques to promote interpretable weight distributions
  3. Advanced projection methods including adaptive convex boundary approximation (CBAP)
  4. Sophisticated numerical stability safeguards throughout the optimization process
  5. Comprehensive trajectory tracking for monitoring convergence dynamics
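
As an illustration of the first item, here is a minimal NumPy sketch of k-means++ style seeding, which picks well-separated data points as starting archetypes. It follows the textbook seeding rule and is not taken from the archetypax source; `kmeanspp_seed` is a hypothetical name.

```python
import numpy as np

def kmeanspp_seed(X, k, rng):
    """Pick k rows of X: the first uniformly at random, the rest with
    probability proportional to squared distance from the nearest chosen row."""
    n = X.shape[0]
    chosen = [int(rng.integers(n))]
    # Squared distance from every point to its nearest chosen point so far.
    sq_dist = np.sum((X - X[chosen[0]]) ** 2, axis=1)
    for _ in range(k - 1):
        probs = sq_dist / sq_dist.sum()
        idx = int(rng.choice(n, p=probs))
        chosen.append(idx)
        sq_dist = np.minimum(sq_dist, np.sum((X - X[idx]) ** 2, axis=1))
    return X[chosen], chosen

rng = np.random.default_rng(42)
X = rng.random((1000, 10))
init_archetypes, indices = kmeanspp_seed(X, k=5, rng=rng)
print(init_archetypes.shape)
```

Because every seed is an actual data point, the initial archetypes start inside the convex hull of the data, which is consistent with the constraint on the archetype matrix.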

Contributing

Contributions are welcome and highly encouraged! To contribute to the project:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add some amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Citation

If you use this package in your research, please cite:

@software{archetypax2025,
  author = {mary},
  title = {archetypax: GPU-accelerated Archetypal Analysis using JAX},
  year = {2025},
  url = {https://github.com/lv416e/archetypax}
}

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
