
Project description

dpgmm

dpgmm is a library implementing a high-performance MCMC sampler for Dirichlet Process Gaussian Mixture Models (DPGMM). Built on PyTorch and accelerated with Triton kernels, it is designed to handle high-dimensional data efficiently.

Key Features

High Performance: Optimized Gibbs sampling leveraging GPU acceleration via PyTorch and Triton kernels.

Data Generation: Built-in utilities to generate high-dimensional synthetic datasets for validation.

Observability: Native integration with Weights & Biases for experiment tracking.

Metrics: Comprehensive tools for calculating assignment log likelihood, data complexity, and data dimensions entanglement.

Modern Stack: Developed using modern Python tooling (uv, ruff, pytest).
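Under the hood, collapsed Gibbs sampling for a DPGMM reassigns points one at a time using Chinese Restaurant Process prior weights: an occupied cluster attracts a point in proportion to its size, and a brand-new cluster in proportion to the concentration parameter alpha. A minimal plain-Python sketch of those prior probabilities (conceptual only, not the library's Triton-accelerated kernel):

```python
def crp_assignment_probs(counts, alpha):
    """Chinese Restaurant Process prior over cluster choices.

    An existing cluster k is picked with probability n_k / (n + alpha);
    a brand-new cluster with probability alpha / (n + alpha).
    """
    n = sum(counts)
    existing = [c / (n + alpha) for c in counts]
    new_cluster = alpha / (n + alpha)
    return existing + [new_cluster]

# 10 points spread over 3 clusters, concentration alpha = 1.0
probs = crp_assignment_probs([5, 3, 2], alpha=1.0)
```

In the full sampler these prior weights are multiplied by each cluster's posterior-predictive likelihood before normalization, which is where the GPU-heavy linear algebra comes in.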

Installation

The package is available on PyPI:

pip install dpgmm

Alternatively, you can clone the repository and set up the environment using uv or pip.

Usage

1. Generating data and running the sampler

Initialize the generator and the Gibbs sampler.

import torch
from dpgmm.datasets import GaussianDataGenerator
from dpgmm.samplers import FullCovarianceCollapsedGibbsSampler, DiagCovarianceCollapsedGibbsSampler

# 1. Generate synthetic data
data_generator = GaussianDataGenerator(cov_type="full")
data_payload = data_generator.generate(n_points=256, data_dim=2, num_components=4)
data_tensor = torch.as_tensor(data_payload["data"])

# 2. Initialize the Sampler
sampler = FullCovarianceCollapsedGibbsSampler(
    init_strategy="init_data_stats",
    max_clusters_num=10,
    batch_size=1
)

# 3. Fit the model
result = sampler.fit(iterations_num=100, data=data_tensor)

# Access results
cluster_params = result["cluster_params"]
cluster_assignment = result["cluster_assignment"]
alpha = result["alpha"]
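The result dictionary above makes it easy to sanity-check a run. For instance, counting cluster occupancies from the assignment vector shows how many of the max_clusters_num slots the sampler actually used (a generic Python sketch; the assignment values here are hypothetical):

```python
from collections import Counter

# hypothetical assignment vector, shaped like result["cluster_assignment"]
assignment = [0, 0, 1, 2, 1, 0]

sizes = Counter(assignment)        # points per cluster
active_clusters = len(sizes)       # clusters with at least one point
largest_cluster, largest_size = sizes.most_common(1)[0]
```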

2. Visualizing results

Visualize the clusters, covariance matrices, and assignments.

from dpgmm.visualisation import ClusterParamsVisualizer

data_visualizer = ClusterParamsVisualizer()

data_visualizer.plot_params_full_covariance(
    data_payload["data"],
    centers=cluster_params["mean"],
    cov_chol=cluster_params["cov_chol"],
    assignment=cluster_assignment,
    trace_alpha=alpha,
)
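The cov_chol parameters are Cholesky factors of the per-cluster covariances. Assuming the usual lower-triangular convention cov = L @ L.T (an assumption; check the library's docs for its exact convention), a covariance matrix can be recovered with this generic sketch:

```python
def cov_from_chol(L):
    """Rebuild cov = L @ L^T from a lower-triangular Cholesky factor."""
    d = len(L)
    return [
        [sum(L[i][k] * L[j][k] for k in range(d)) for j in range(d)]
        for i in range(d)
    ]

# example 2x2 lower-triangular factor
L = [[2.0, 0.0],
     [0.6, 1.0]]
cov = cov_from_chol(L)
# cov is symmetric positive-definite by construction
```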

3. Checkpointing

You can save checkpoints during training and resume from them later.

# To save during training, specify an out_dir
sampler.fit(iterations_num=25, data=data_tensor, out_dir="out/save_and_load")

# To resume, pass the path to the snapshot directory in kwargs
additional_kwargs = {"restore_snapshot_pkl_path": "out/save_and_load/"}

sampler_restored = FullCovarianceCollapsedGibbsSampler(
    init_strategy="init_data_stats",
    max_clusters_num=10,
    batch_size=1,
    **additional_kwargs,
)
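Conceptually, a snapshot is just serialized sampler state. A minimal sketch of the save/restore round trip with pickle (the file name follows the cgs_*.pkl pattern seen in the metrics section below, but the actual on-disk format is the library's, not this):

```python
import os
import pickle
import tempfile

# hypothetical sampler state captured at iteration 25
state = {"iteration": 25, "alpha": 1.3, "assignment": [0, 1, 1, 0]}

with tempfile.TemporaryDirectory() as out_dir:
    snapshot_path = os.path.join(out_dir, "cgs_25.pkl")
    with open(snapshot_path, "wb") as f:
        pickle.dump(state, f)           # save during training
    with open(snapshot_path, "rb") as f:
        restored = pickle.load(f)       # restore before resuming
```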

4. Calculating metrics

Data complexity

Estimate entropy from sampling versus data to gauge model fit.

from dpgmm.metrics import ComplexityFromTraceEstimator

estimator = ComplexityFromTraceEstimator(
    trace_path="/path/to/results/cgs_19.pkl",
    data_trace_path="/path/to/results/cgs_0.pkl",
    samples_num=100_000,
)

entropy_sampled = estimator.estimate_entropy_with_sampling()
entropy_data = estimator.estimate_entropy_on_data(data_tensor)

print(f"Entropy from sampling: {entropy_sampled}")
print(f"Entropy on data: {entropy_data}")
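The sampling-based estimate works by averaging the negative log density over draws from the model, H ≈ -E[log p(x)]. The idea is easy to verify on a distribution whose entropy is known in closed form (a standalone sketch, independent of dpgmm):

```python
import math
import random

random.seed(0)
sigma = 2.0

# closed-form differential entropy of N(0, sigma^2)
h_exact = 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)

# Monte-Carlo estimate: draw from p, average -log p(x)
n = 100_000
total = 0.0
for _ in range(n):
    x = random.gauss(0.0, sigma)
    log_p = -0.5 * math.log(2 * math.pi * sigma ** 2) - x * x / (2 * sigma ** 2)
    total -= log_p
h_mc = total / n
```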

Data dimensions entanglement

Calculate the KL divergence between joint and product marginals to measure feature entanglement.

from dpgmm.metrics import EntanglementFromTraceEstimator

estimator = EntanglementFromTraceEstimator(
    trace_path="/path/to/results/cgs_99.pkl",
    samples_num=100_000
)

dkl_joint_prod = estimator.calculate_joint_and_prod_dkl()
dkl_symmetric = estimator.calculate_symmetric_dkl()

print(f"KL(Joint || Marginals Prod):  {dkl_joint_prod:.4f}")
print(f"Symmetric KL:  {dkl_symmetric:.4f}")
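For a bivariate Gaussian with unit variances and correlation rho, the divergence between the joint and the product of its marginals (i.e., the mutual information between the two dimensions) has a closed form, which is handy for sanity-checking the estimator on synthetic data (standalone sketch):

```python
import math

def joint_vs_product_dkl(rho):
    """KL( N(0, Sigma) || product of marginals ) for a bivariate
    Gaussian with unit variances and correlation rho:
    I = -0.5 * ln(1 - rho^2)."""
    return -0.5 * math.log(1.0 - rho * rho)

dkl_independent = joint_vs_product_dkl(0.0)  # independent dims -> 0
dkl_correlated = joint_vs_product_dkl(0.9)   # strong entanglement
```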

Integrations & observability

The sampler supports W&B out of the box for tracking loss curves, cluster evolution, and system metrics. To enable experiment tracking, export the WANDB_API_KEY environment variable:

export WANDB_API_KEY=your_key_here

Benchmarks

Thanks to Triton kernels, dpgmm achieves significant speedups compared to standard implementations, especially in high-dimensional experiments. The following table showcases the average iteration time (in seconds) for $N=1000$ points using the full covariance model.

| Data dim | PyTorch CPU [s] | Optimized GPU [s] | Speedup |
|---------:|----------------------:|--------------------:|----------------:|
| 128 | $2.893 \pm 0.702$ | $0.656 \pm 0.004$ | $\times 4.41$ |
| 256 | $6.300 \pm 1.237$ | $0.520 \pm 0.027$ | $\times 12.11$ |
| 512 | $23.857 \pm 1.843$ | $0.456 \pm 0.018$ | $\times 52.36$ |
| 1024 | $53.204 \pm 2.394$ | $1.196 \pm 0.024$ | $\times 44.47$ |
| 2048 | $140.795 \pm 4.535$ | $3.141 \pm 0.138$ | $\times 44.83$ |
| 4096 | $494.269 \pm 6.651$ | $13.414 \pm 0.611$ | $\times 36.85$ |
| 8192 | $3479.447 \pm 72.901$ | $37.880 \pm 1.458$ | $\times 91.85$ |
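Each Speedup entry is simply the ratio of the mean CPU time to the mean GPU time; small deviations come from rounding of the reported means:

```python
# mean iteration times (seconds) copied from the table above
cpu_times = {128: 2.893, 512: 23.857, 8192: 3479.447}
gpu_times = {128: 0.656, 512: 0.456, 8192: 37.880}

speedups = {d: cpu_times[d] / gpu_times[d] for d in cpu_times}
```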

Development

This project uses uv for dependency management and Task (go-task) for orchestrating development workflows.

# Install dependencies
uv sync

# Install pre-commit hooks
uv run pre-commit install

Task Automation

A Taskfile.yml is provided to simplify common development commands:

# Run linter and formatter (Ruff)
uv run task lint

# Run security audits (Bandit & Safety)
uv run task audit

# Check code complexity (Xenon)
uv run task complexity

# Run all quality and security checks
uv run task check-all

# Build documentation
uv run task build-docs

# Run all tests
uv run task run-tests

Download files

Download the file for your platform.

Source Distribution

dpgmm-0.1.0.tar.gz (46.2 kB)

Uploaded Source

Built Distribution


dpgmm-0.1.0-py3-none-any.whl (78.4 kB)

Uploaded Python 3

File details

Details for the file dpgmm-0.1.0.tar.gz.

File metadata

  • Download URL: dpgmm-0.1.0.tar.gz
  • Upload date:
  • Size: 46.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.3 {"installer":{"name":"uv","version":"0.10.3","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for dpgmm-0.1.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | f1c0695ffe32ebece1d43cfe81c4c981d822451118e40e003a51e862154d5bca |
| MD5 | 7a24f6b9e586df630deee430e4dc7878 |
| BLAKE2b-256 | b8cec4b7f6ad8924a09251c5f8cfb2d1c474e5db3541fcb3dc5ad39594888d2b |


File details

Details for the file dpgmm-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: dpgmm-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 78.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.3 {"installer":{"name":"uv","version":"0.10.3","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for dpgmm-0.1.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | c37f80c5a768158b9f4a1360f1c846f94edd029e6f8e72bd4b5d7e50a5b623cc |
| MD5 | 5e5bb1e04bd36cb9d68d4ae18199a39f |
| BLAKE2b-256 | 928d2315efd4c505dcddf05d213608acff3602733fff9094fdc83c8b2328c69f |

