GPU-accelerated Archetypal Analysis implementation using JAX
ArchetypAX – Hardware-accelerated Archetypal Analysis with JAX
Discover extreme patterns in your data with GPU/TPU-accelerated Archetypal Analysis, high-performance convex hull optimization, and interpretable matrix factorization.
Table of Contents
- Overview
- Features
- Installation
- Quick Start
- Import Patterns
- Changelog
- Documentation
- Examples
- How It Works
- Contributing
- Community
- Citation
- License
Overview
archetypax is a high-performance implementation of Archetypal Analysis (AA) leveraging JAX for GPU acceleration.
Archetypal Analysis is a powerful matrix factorization technique that represents each data point
as a convex combination of a small number of extreme points (archetypes), which are themselves constrained to lie within the data's convex hull.
Unlike traditional dimensionality reduction techniques such as PCA, which find abstract orthogonal components,
AA discovers interpretable extremal points that often correspond to meaningful prototypes.
This makes it valuable for applications requiring both dimensionality reduction and human-interpretable insights,
such as market segmentation, document analysis, and anomaly detection.
Features
Performance & Stability
- 🚀 GPU/TPU acceleration using JAX
- 🧠 Smart initialization (k-means++, directional)
- 🛠️ Numerical stability & convergence techniques
Usability & Compatibility
- 📊 scikit-learn compatible API (fit/transform)
- 📋 Thorough documentation
Interpretability & Visualization
- 🔍 Meaningful interpretable archetypes
- 📈 Advanced tracking & optimization trajectory monitoring
- 🎯 Comprehensive evaluation & visualization tooling
Related Projects and Techniques
ArchetypAX can be used alongside or compared with these related approaches:
- PCA: Principal Component Analysis finds orthogonal directions of maximum variance
- NMF: Non-negative Matrix Factorization decomposes data into non-negative components
- k-means: Clustering technique that partitions data into k clusters
- JAX Ecosystem: Compatible with JAX-based machine learning frameworks like Flax
- scikit-learn: Follows similar API conventions, allowing easy integration
Installation
Install with pip, uv, or poetry:
```bash
# pip: latest release from PyPI, or the development version from GitHub
pip install archetypax
pip install git+https://github.com/lv416e/archetypax.git

# uv
uv pip install archetypax
uv pip install git+https://github.com/lv416e/archetypax.git

# poetry
poetry add archetypax
poetry add git+https://github.com/lv416e/archetypax.git
```
Install optional dependencies (the quotes keep shells such as zsh from treating the brackets as a glob pattern):
```bash
pip install "archetypax[dev]"       # Development dependencies
pip install "archetypax[examples]"  # Example dependencies
pip install "archetypax[docs]"      # Documentation dependencies
```
Requirements
| Type | Dependency | Version | Description |
|---|---|---|---|
| Core | Python | >=3.10 | Required for modern language features and compatibility with JAX |
| Core | JAX | >=0.4.0 | Powers the hardware acceleration and automatic differentiation |
| Core | NumPy | >=1.20.0 | Handles core numerical operations and array manipulations |
| Core | optax | >=0.1.0 | JAX-based optimization framework for gradient-based updates |
| Core | pandas | >=1.3.0 | Data manipulation and analysis library |
| Core | scikit-learn | >=1.0.0 | Provides machine learning utilities and compatible interfaces |
| Examples | jupyter | >=1.0.0 | Interactive computing environment for notebooks |
| Examples | matplotlib | >=3.7.5 | Required for visualization functionality |
| Examples | seaborn | >=0.13.2 | Statistical data visualization |
| Dev | black | ==23.7.0 | Code formatter |
| Dev | mypy | >=1.8.0 | Static type checker |
| Dev | pytest | >=7.0.0 | Testing framework |
| Dev | ruff | >=0.9.0 | Fast Python linter and formatter |
Quick Start
```python
import numpy as np
# ImprovedArchetypalAnalysis is imported under the shorter name ArchetypalAnalysis
from archetypax import ImprovedArchetypalAnalysis as ArchetypalAnalysis

# Generate sample data: 1,000 samples with 10 features
np.random.seed(42)
X = np.random.rand(1000, 10)

# Initialize and fit the model, returning the archetype weights for X
model = ArchetypalAnalysis(n_archetypes=5)
weights = model.fit_transform(X)

# Get the archetypes
archetypes = model.archetypes

# Reconstruct the data from the fitted weights and archetypes
X_reconstructed = model.reconstruct()

# Calculate reconstruction error
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")
```
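In the Quick Start, each row of `weights` is assumed to hold one sample's mixing coefficients over the archetypes, and reconstruction is assumed to be their convex combination, X_hat = W A. The NumPy-only sketch below illustrates that relationship with random stand-in data rather than archetypax output; the helpers `rows_on_simplex` and `reconstruct_from` are hypothetical names, and the same checks could be pointed at `weights` and `model.archetypes` from the Quick Start.

```python
import numpy as np


def rows_on_simplex(W, atol=1e-6):
    """True if every row of W is non-negative and sums to 1 (within atol)."""
    W = np.asarray(W, dtype=float)
    return bool(np.all(W >= -atol) and np.allclose(W.sum(axis=1), 1.0, atol=atol))


def reconstruct_from(W, A):
    """Convex-combination reconstruction X_hat = W @ A (one archetype per row of A)."""
    return np.asarray(W, dtype=float) @ np.asarray(A, dtype=float)


rng = np.random.default_rng(0)
A = rng.random((5, 10))                   # 5 stand-in archetypes with 10 features
W = rng.dirichlet(np.ones(5), size=1000)  # 1000 weight rows drawn from the simplex
X_hat = reconstruct_from(W, A)
print(X_hat.shape, rows_on_simplex(W))    # → (1000, 10) True
```

Because every reconstructed point is a convex combination of archetypes, each feature of X_hat stays between the smallest and largest value of that feature across the archetypes.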
Import Patterns
ArchetypAX supports multiple import patterns for flexibility:
Direct Class Imports (Recommended)
```python
from archetypax import ArchetypalAnalysis, ImprovedArchetypalAnalysis, BiarchetypalAnalysis, ArchetypeTracker
```
Explicit Module Imports
```python
from archetypax.models.base import ArchetypalAnalysis
from archetypax.models.biarchetypes import BiarchetypalAnalysis
from archetypax.tools.evaluation import ArchetypalAnalysisEvaluator
from archetypax.tools.tracker import ArchetypeTracker
```
Module-Level Imports
```python
from archetypax.models import ArchetypalAnalysis
from archetypax.tools import ArchetypalAnalysisVisualizer, ArchetypeTracker
```
Changelog
For a detailed list of changes and version history, please see the CHANGELOG.md file.
Documentation
Parameters
ArchetypalAnalysis / ImprovedArchetypalAnalysis
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_archetypes | int | - | Number of archetypes to find |
| max_iter | int | 500 | Maximum number of iterations |
| tol | float | 1e-6 | Convergence tolerance |
| random_seed | int | 42 | Random seed for initialization |
| learning_rate | float | 0.001 | Learning rate for the optimizer |
| lambda_reg | float | 0.01 | Regularization strength for weight distribution |
| normalize | bool | False | Whether to normalize features before fitting |
| projection_method | str | "cbap" | Method for projecting archetypes ("cbap", "convex_hull", "knn") |
| projection_alpha | float | 0.1 | Blending coefficient for boundary projection |
| archetype_init_method | str | "directional" | Initialization strategy ("directional", "kmeans++", "qhull") |
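The `tol` parameter is described only as a convergence tolerance. A common convention, assumed here rather than taken from the archetypax source, is to stop once the relative change in loss between consecutive iterations falls below `tol`, with `max_iter` as a hard cap. The hypothetical helpers below sketch that rule on a synthetic loss curve.

```python
def has_converged(prev_loss, curr_loss, tol=1e-6):
    """Hypothetical stopping rule: relative change in loss below tol.

    The denominator is floored so a loss of exactly 0 does not divide by zero.
    """
    return abs(prev_loss - curr_loss) / max(abs(prev_loss), 1e-12) < tol


def run_until_converged(loss_at, tol=1e-6, max_iter=500):
    """Evaluate loss_at(0), loss_at(1), ... until has_converged fires or max_iter is reached.

    Returns (iterations_run, final_loss).
    """
    prev = loss_at(0)
    for it in range(1, max_iter + 1):
        curr = loss_at(it)
        if has_converged(prev, curr, tol):
            return it, curr
        prev = curr
    return max_iter, prev


# A loss that halves its distance to 1.0 every iteration stops after 20 iterations.
print(run_until_converged(lambda t: 1.0 + 0.5 ** t))
```

A loss that keeps improving by more than `tol` in relative terms runs all the way to `max_iter`, which is why the two parameters are usually tuned together.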
BiarchetypalAnalysis
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_row_archetypes | int | - | Number of archetypes in observation space |
| n_col_archetypes | int | - | Number of archetypes in feature space |
| max_iter | int | 500 | Maximum number of iterations |
| tol | float | 1e-6 | Convergence tolerance |
| random_seed | int | 42 | Random seed for initialization |
| learning_rate | float | 0.001 | Learning rate for the optimizer |
| projection_method | str | "default" | Method for projecting archetypes |
| lambda_reg | float | 0.01 | Regularization strength for entropy terms |
Methods
| Method | Returns | Description |
|---|---|---|
| fit(X) | model | Fit the model to the data |
| transform(X) | array | Transform new data to archetype weights |
| fit_transform(X) | array | Fit the model and transform the data |
| reconstruct(X) | array | Reconstruct data from archetype weights |
| get_loss_history() | array | Get the loss history from training |
| get_all_archetypes() | tuple | Get both sets of archetypes (BiarchetypalAnalysis only) |
| get_all_weights() | tuple | Get both sets of weights (BiarchetypalAnalysis only) |
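The methods above follow the scikit-learn `fit`/`transform` convention. As a rough, self-contained illustration of that contract (not of archetypax's JAX implementation), the NumPy sketch below defines a hypothetical `ToyArchetypalAnalysis` class that alternates updates of the weights W and of a mixing matrix B, with archetypes A = B X. The class name, the `B_`, `archetypes_`, `weights_` and `loss_history_` attributes, and the use of Frank-Wolfe steps are all assumptions for this sketch; per its requirements, archetypax relies on JAX/optax gradient-based updates with its own projection methods instead.

```python
import numpy as np


def _fw_weights(W, X, A, n_steps):
    """Frank-Wolfe steps on min ||X - W A||_F^2, keeping every row of W on the simplex."""
    rows = np.arange(W.shape[0])
    for _ in range(n_steps):
        R = X - W @ A                                   # residuals, (n, d)
        G = -2.0 * R @ A.T                              # gradient w.r.t. W, (n, k)
        S = np.zeros_like(W)
        S[rows, G.argmin(axis=1)] = 1.0                 # best simplex vertex per row
        D = S - W                                       # feasible direction
        DA = D @ A
        num = np.sum(R * DA, axis=1)
        den = np.sum(DA * DA, axis=1)
        # Exact line search per row, clipped to [0, 1] so rows stay on the simplex
        step = np.clip(np.divide(num, den, out=np.zeros_like(num), where=den > 0), 0.0, 1.0)
        W = W + step[:, None] * D
    return W


def _fw_mixing(B, W, X, n_steps):
    """Frank-Wolfe steps on min ||X - W B X||_F^2, keeping every row of B on the simplex."""
    rows = np.arange(B.shape[0])
    for _ in range(n_steps):
        R = X - W @ (B @ X)
        G = -2.0 * W.T @ R @ X.T                        # gradient w.r.t. B, (k, n)
        S = np.zeros_like(B)
        S[rows, G.argmin(axis=1)] = 1.0
        D = S - B
        WDX = W @ (D @ X)
        den = np.sum(WDX * WDX)
        step = float(np.clip(np.sum(R * WDX) / den, 0.0, 1.0)) if den > 0 else 0.0
        B = B + step * D
    return B


class ToyArchetypalAnalysis:
    """Hypothetical NumPy estimator with a scikit-learn style interface (not archetypax)."""

    def __init__(self, n_archetypes, max_iter=50, random_seed=42):
        self.n_archetypes = n_archetypes
        self.max_iter = max_iter
        self.random_seed = random_seed

    def fit(self, X):
        X = np.asarray(X, dtype=float)
        n, k = X.shape[0], self.n_archetypes
        rng = np.random.default_rng(self.random_seed)
        B = np.zeros((k, n))
        B[np.arange(k), rng.choice(n, size=k, replace=False)] = 1.0  # start at k data points
        W = np.full((n, k), 1.0 / k)
        self.loss_history_ = []
        for _ in range(self.max_iter):
            W = _fw_weights(W, X, B @ X, n_steps=5)     # update weights, archetypes fixed
            B = _fw_mixing(B, W, X, n_steps=5)          # update archetypes, weights fixed
            self.loss_history_.append(float(np.sum((X - W @ (B @ X)) ** 2)))
        self.B_, self.archetypes_, self.weights_ = B, B @ X, W
        return self                                     # fit returns the estimator itself

    def transform(self, X):
        X = np.asarray(X, dtype=float)
        W = np.full((X.shape[0], self.n_archetypes), 1.0 / self.n_archetypes)
        return _fw_weights(W, X, self.archetypes_, n_steps=100)

    def fit_transform(self, X):
        return self.fit(X).weights_

    def reconstruct(self, X=None):
        W = self.weights_ if X is None else self.transform(X)
        return W @ self.archetypes_
```

Frank-Wolfe is used here only because each step moves toward a simplex vertex, so the constraints hold without an explicit projection and, with exact line search, the loss never increases. A gradient method like the one archetypax describes would instead need a projection or reparametrization after each update.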
Examples
Visualizing Archetypes in 2D Data
```python
import numpy as np
import matplotlib.pyplot as plt
from archetypax import ImprovedArchetypalAnalysis
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Generate 2D data: points sampled uniformly inside a triangle
np.random.seed(42)
n_samples = 500
vertices = np.array([[0, 0], [1, 0], [0.5, 0.866]])
weights = np.random.dirichlet(np.ones(3), size=n_samples)
X = weights @ vertices

# Fit archetypal analysis with 3 archetypes
model = ImprovedArchetypalAnalysis(n_archetypes=3, archetype_init_method="directional")
model.fit(X)

# Plot original data and archetypes
plt.figure(figsize=(10, 8))
ArchetypalAnalysisVisualizer.plot_archetypes_2d(model, X)
plt.title("Archetypal Analysis of 2D Data")
plt.show()
```
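Because the triangle data are generated from three known vertices, fitted archetypes can be scored against them. There is no guarantee the archetypes come back in the same order as the vertices, so the hypothetical helper below matches them by brute force over all orderings before averaging distances. It could be called as `archetype_recovery_error(np.asarray(model.archetypes), vertices)`, assuming `model.archetypes` holds one archetype per row as in the Quick Start; the example data here are a hand-made stand-in, not model output.

```python
import itertools

import numpy as np


def archetype_recovery_error(estimated, true):
    """Mean Euclidean distance between estimated and true archetypes under the best matching.

    Brute force over all k! orderings, so only practical for small k.
    """
    estimated = np.asarray(estimated, dtype=float)
    true = np.asarray(true, dtype=float)
    best = np.inf
    for perm in itertools.permutations(range(len(true))):
        err = np.linalg.norm(estimated[list(perm)] - true, axis=1).mean()
        best = min(best, err)
    return float(best)


vertices = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.866]])
# Same vertices in a different order, each nudged by 0.05 along x
estimated = vertices[[2, 0, 1]] + np.array([0.05, 0.0])
print(archetype_recovery_error(estimated, vertices))
```

For larger k, `scipy.optimize.linear_sum_assignment` solves the same matching problem without enumerating every permutation.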
Using Biarchetypal Analysis
```python
import numpy as np
from archetypax import BiarchetypalAnalysis

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(500, 5)

# Initialize and fit the model with row and column archetypes
model = BiarchetypalAnalysis(
    n_row_archetypes=2,  # Number of archetypes in observation space
    n_col_archetypes=2,  # Number of archetypes in feature space
    max_iter=500,
    random_seed=42,
)
model.fit(X)

# Get both sets of archetypes
row_archetypes, col_archetypes = model.get_all_archetypes()
print("Row archetypes shape:", row_archetypes.shape)
print("Column archetypes shape:", col_archetypes.shape)

# Get both sets of weights
row_weights, col_weights = model.get_all_weights()
print("Row weights shape:", row_weights.shape)
print("Column weights shape:", col_weights.shape)

# Reconstruct data using biarchetypes
X_reconstructed = model.reconstruct()
mse = np.mean((X - X_reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.6f}")
```
Tracking Archetype Evolution
```python
import numpy as np
import matplotlib.pyplot as plt
from archetypax import ArchetypeTracker

# Generate sample data
np.random.seed(42)
X = np.random.rand(1000, 10)

# Initialize the tracker
tracker = ArchetypeTracker(
    n_archetypes=3,
    max_iter=300,
    random_seed=42,
)

# Fit the model while tracking archetype movement
tracker.fit(X)

# Visualize the archetype movement trajectory
tracker.visualize_movement()

# Visualize boundary proximity over iterations
tracker.visualize_boundary_proximity()
```
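The internals of `ArchetypeTracker` are not documented here, but the idea of tracking archetype movement can be sketched with a small, hypothetical `MovementLog` helper that stores a snapshot of the archetype matrix at every iteration and summarizes how far each archetype travels. The trajectory below is synthetic and made up for illustration; it does not come from archetypax.

```python
import numpy as np


class MovementLog:
    """Hypothetical helper: records archetype snapshots and summarizes their motion."""

    def __init__(self):
        self.snapshots = []

    def record(self, archetypes):
        # Copy so later in-place updates to the optimizer's array do not alter history
        self.snapshots.append(np.array(archetypes, dtype=float, copy=True))

    def step_displacements(self):
        """Distance each archetype moved per iteration, shape (T - 1, k)."""
        S = np.stack(self.snapshots)                    # (T, k, d)
        return np.linalg.norm(np.diff(S, axis=0), axis=2)

    def path_lengths(self):
        """Total distance travelled by each archetype, shape (k,)."""
        return self.step_displacements().sum(axis=0)


# Synthetic trajectory: archetypes move halfway toward fixed targets each iteration
start = np.array([[0.4, 0.4], [0.6, 0.3], [0.5, 0.5]])
target = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.866]])
log = MovementLog()
current = start.copy()
for _ in range(10):
    log.record(current)
    current += 0.5 * (target - current)

print(log.step_displacements().shape)  # → (9, 3)
print(log.path_lengths())
```

Shrinking per-step displacements are one informal sign that the archetypes are settling; comparing path length with net displacement shows whether they moved directly or wandered.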
How It Works
Archetypal Analysis solves the following optimization problem:
Given a data matrix $\mathbf{X} \in \mathbb{R}^{n \times d}$ with $n$ samples and $d$ features, find $k$ archetypes $\mathbf{A} \in \mathbb{R}^{k \times d}$ and weights $\mathbf{W} \in \mathbb{R}^{n \times k}$ that solve:
$$ \min_{\mathbf{W},\, \mathbf{A}} \ \| \mathbf{X} - \mathbf{W} \mathbf{A} \|_F^2 $$
subject to:
- $\mathbf{W}$ is non-negative
- Each row of $\mathbf{W}$ sums to 1 (simplex constraint)
- Each archetype (row of $\mathbf{A}$) lies within the convex hull of the rows of $\mathbf{X}$, i.e. $\mathbf{A} = \mathbf{B}\mathbf{X}$ for some non-negative $\mathbf{B} \in \mathbb{R}^{k \times n}$ whose rows sum to 1
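One standard way to enforce the simplex constraint on $\mathbf{W}$ (and on $\mathbf{B}$) inside a gradient-based solver is to project each row back onto the probability simplex after every update. The sketch below implements the well-known sort-based Euclidean projection in NumPy. Whether archetypax uses this projection, a softmax parametrization, or something else is not stated here, so treat `project_rows_to_simplex` as a hypothetical helper.

```python
import numpy as np


def project_rows_to_simplex(V):
    """Euclidean projection of each row of V onto {w : w >= 0, sum(w) = 1}."""
    V = np.asarray(V, dtype=float)
    n, k = V.shape
    U = -np.sort(-V, axis=1)                      # each row sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0              # cumulative sums minus the target total
    j = np.arange(1, k + 1)
    active = U - css / j > 0                      # condition from the sort-based algorithm
    rho = k - np.argmax(active[:, ::-1], axis=1)  # largest j where the condition holds
    theta = css[np.arange(n), rho - 1] / rho      # per-row shift
    return np.maximum(V - theta[:, None], 0.0)


# Rows as they might look after a hypothetical gradient step
W_step = np.array([[0.7, 0.2, 0.4],    # sums to 1.3
                   [2.0, 0.0, -1.0]])  # contains a negative entry
W_proj = project_rows_to_simplex(W_step)
print(W_proj)
print(W_proj.sum(axis=1))
```

The projection subtracts a single per-row shift and clips at zero, so rows already on the simplex are returned unchanged and the cost is dominated by the sort, O(k log k) per row.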
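The biarchetypal extension factorizes the observation and feature spaces simultaneously:
$$ \mathbf{X} \approx \boldsymbol{\alpha} \, \boldsymbol{\beta} \, \mathbf{X} \, \boldsymbol{\theta} \, \boldsymbol{\gamma} $$
where, for $k_r$ row archetypes and $k_c$ column archetypes, $\boldsymbol{\alpha} \in \mathbb{R}^{n \times k_r}$, $\boldsymbol{\beta} \in \mathbb{R}^{k_r \times n}$, $\boldsymbol{\theta} \in \mathbb{R}^{d \times k_c}$ and $\boldsymbol{\gamma} \in \mathbb{R}^{k_c \times d}$.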
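In the factorization $\mathbf{X} \approx \boldsymbol{\alpha}\boldsymbol{\beta}\mathbf{X}\boldsymbol{\theta}\boldsymbol{\gamma}$, the matrix shapes are fixed by matrix multiplication once the numbers of row archetypes $k_r$ (n_row_archetypes) and column archetypes $k_c$ (n_col_archetypes) are chosen. The NumPy sketch below checks that bookkeeping with random matrices. It additionally assumes the usual biarchetypal constraints (rows of α and β, and columns of θ and γ, on the simplex), which the equation does not spell out; the variable names are illustrative, and whether `get_all_archetypes()` returns exactly these matrices is not documented here.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k_r, k_c = 500, 5, 2, 2                     # sizes mirror the BiarchetypalAnalysis example
X = rng.random((n, d))

# Assumed constraints: rows of alpha and beta, and columns of theta and gamma, lie on the simplex
alpha = rng.dirichlet(np.ones(k_r), size=n)       # (n, k_r)
beta = rng.dirichlet(np.ones(n), size=k_r)        # (k_r, n)
theta = rng.dirichlet(np.ones(d), size=k_c).T     # (d, k_c)
gamma = rng.dirichlet(np.ones(k_c), size=d).T     # (k_c, d)

row_archetypes = beta @ X                         # (k_r, d): mixtures of observations
col_archetypes = X @ theta                        # (n, k_c): mixtures of features
core = beta @ X @ theta                           # (k_r, k_c): links row and column archetypes
X_hat = alpha @ core @ gamma                      # (n, d): same shape as X

for name, M in [("row_archetypes", row_archetypes), ("col_archetypes", col_archetypes),
                ("core", core), ("X_hat", X_hat)]:
    print(name, M.shape)
```

Under these constraints every entry of X_hat is a weighted average of entries of X with non-negative weights summing to 1, so the reconstruction can never leave the range of the original data.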
This implementation uses JAX's automatic differentiation and optimization tools to efficiently solve these problems on GPUs. It also incorporates several advanced enhancements:
- Strategic initialization methods including directional initialization, k-means++ style, and convex hull approximation
- Intelligent regularization techniques to promote interpretable weight distributions
- Advanced projection methods including adaptive convex boundary approximation (CBAP)
- Sophisticated numerical stability safeguards throughout the optimization process
- Comprehensive trajectory tracking for monitoring convergence dynamics
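The "k-means++ style" initialization listed above presumably seeds archetypes from well-spread data points. The sketch below shows classic k-means++ D² seeding in NumPy as an illustration of that general technique; it is not archetypax's `kmeans++` option, the function name `kmeanspp_seed_indices` is hypothetical, and the directional and convex hull strategies are not covered.

```python
import numpy as np


def kmeanspp_seed_indices(X, k, seed=42):
    """Pick k row indices of X by k-means++ D^2 sampling."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    chosen = [int(rng.integers(n))]                   # first seed: uniform at random
    d2 = np.sum((X - X[chosen[0]]) ** 2, axis=1)      # squared distance to nearest seed
    for _ in range(1, k):
        total = d2.sum()
        if total == 0:                                # every point coincides with a seed
            raise ValueError("fewer than k distinct points in X")
        nxt = int(rng.choice(n, p=d2 / total))        # far points are more likely
        chosen.append(nxt)
        d2 = np.minimum(d2, np.sum((X - X[nxt]) ** 2, axis=1))
    return np.array(chosen)


rng = np.random.default_rng(0)
X = rng.random((1000, 10))
idx = kmeanspp_seed_indices(X, k=5)
initial_archetypes = X[idx]                           # (5, 10): initial archetypes from the data
```

Because each new seed is drawn with probability proportional to its squared distance from the seeds already chosen, points on the periphery of the data, where archetypes tend to live, are favored, and an already chosen point has zero probability of being picked again.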
Contributing
Contributions are welcome and highly encouraged! Before submitting a pull request, please review the following resources:
- Code of Conduct: Guidelines for community participation
- Security Policy: Vulnerability reporting and handling procedures
To contribute to the project:
- Fork the repository
- Create a feature branch (`git checkout -b feature/amazing-feature`)
- Commit your changes (`git commit -m 'Add some amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
Community
- 🐞 Issue Tracker: Report bugs and request features
- 💬 Discussions: Questions and general community interactions
Citation
If you use this package in your research, please cite:
```bibtex
@software{archetypax2025,
  author = {mary},
  title = {archetypax: GPU-accelerated Archetypal Analysis using JAX},
  year = {2025},
  url = {https://github.com/lv416e/archetypax}
}
```
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.