dpgmm
dpgmm is a library implementing a high-performance MCMC sampler for Dirichlet Process Gaussian Mixture Models (DPGMM). Built on PyTorch and accelerated with Triton kernels, it is designed to handle high-dimensional data efficiently.
Key Features
- **High Performance**: Optimized Gibbs sampling leveraging GPU acceleration via PyTorch and Triton kernels.
- **Data Generation**: Built-in utilities to generate high-dimensional synthetic datasets for validation.
- **Observability**: Native integration with Weights & Biases for experiment tracking.
- **Metrics**: Tools for computing assignment log-likelihood, data complexity, and data-dimension entanglement.
- **Modern Stack**: Developed with modern Python tooling (uv, ruff, pytest).
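The collapsed Gibbs step at the heart of such a sampler reassigns each point using Chinese-restaurant-process weights: an existing cluster k attracts the point with weight proportional to its size n_k times its predictive likelihood, while a new cluster is opened with weight proportional to the concentration alpha. The following is an illustrative NumPy sketch of that step with a toy fixed-variance 1-D Gaussian likelihood; it is not dpgmm's implementation, and all names are hypothetical:

```python
import numpy as np

def crp_assignment_probs(point, cluster_means, cluster_counts, alpha, sigma2=1.0):
    """Normalized CRP assignment probabilities for a single point.

    Existing cluster k gets weight n_k * N(point | mean_k, sigma2); a new
    cluster gets weight alpha * N(point | 0, sigma2 + 1).  This is a toy
    sketch with a fixed-variance likelihood, not the conjugate predictive
    a full collapsed sampler would use.
    """
    def gauss(x, mu, var):
        return np.exp(-0.5 * (x - mu) ** 2 / var) / np.sqrt(2 * np.pi * var)

    weights = [n_k * gauss(point, mu_k, sigma2)
               for mu_k, n_k in zip(cluster_means, cluster_counts)]
    weights.append(alpha * gauss(point, 0.0, sigma2 + 1.0))  # new-cluster term
    weights = np.asarray(weights)
    return weights / weights.sum()

# A point near the first cluster's mean is most likely assigned to it.
probs = crp_assignment_probs(0.1, cluster_means=[0.0, 5.0],
                             cluster_counts=[10, 3], alpha=1.0)
```

The returned vector has one entry per existing cluster plus a final entry for opening a new one; the actual sampler draws the assignment from this distribution.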
Installation
The package will be available on PyPI soon.

```shell
# Coming soon
pip install dpgmm
```

Until then, clone the repository and set up the environment using uv or pip.
Usage
1. Generating data and running the sampler
Initialize the generator and the Gibbs sampler.
```python
import torch

from dpgmm.datasets import GaussianDataGenerator
from dpgmm.samplers import FullCovarianceCollapsedGibbsSampler, DiagCovarianceCollapsedGibbsSampler

# 1. Generate synthetic data
data_generator = GaussianDataGenerator(cov_type="full")
data_payload = data_generator.generate(n_points=256, data_dim=2, num_components=4)
data_tensor = torch.as_tensor(data_payload["data"])

# 2. Initialize the sampler
sampler = FullCovarianceCollapsedGibbsSampler(
    init_strategy="init_data_stats",
    max_clusters_num=10,
    batch_size=1,
)

# 3. Fit the model
result = sampler.fit(iterations_num=100, data=data_tensor)

# Access results
cluster_params = result["cluster_params"]
cluster_assignment = result["cluster_assignment"]
alpha = result["alpha"]
```
2. Visualizing results
Visualize the clusters, covariance matrices, and assignments.
```python
from dpgmm.visualisation import ClusterParamsVisualizer

data_visualizer = ClusterParamsVisualizer()
data_visualizer.plot_params_full_covariance(
    data_payload["data"],
    centers=cluster_params["mean"],
    cov_chol=cluster_params["cov_chol"],
    assignment=cluster_assignment,
    trace_alpha=alpha,
)
```
3. Checkpointing
You can save checkpoints during training and resume from them later.
```python
# To save during training, specify an out_dir
sampler.fit(iterations_num=25, data=data_tensor, out_dir="out/save_and_load")

# To resume, pass the path to the snapshot directory in kwargs
additional_kwargs = {"restore_snapshot_pkl_path": "out/save_and_load/"}
sampler_restored = FullCovarianceCollapsedGibbsSampler(
    init_strategy="init_data_stats",
    max_clusters_num=10,
    batch_size=1,
    **additional_kwargs,
)
```
4. Calculating metrics
Data complexity
Estimate entropy from sampling versus data to gauge model fit.
```python
from dpgmm.metrics import ComplexityFromTraceEstimator

estimator = ComplexityFromTraceEstimator(
    trace_path="/path/to/results/cgs_19.pkl",
    data_trace_path="/path/to/results/cgs_0.pkl",
    samples_num=100_000,
)
entropy_sampled = estimator.estimate_entropy_with_sampling()
entropy_data = estimator.estimate_entropy_on_data(data_tensor)

print(f"Entropy from sampling: {entropy_sampled}")
print(f"Entropy on data: {entropy_data}")
```
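The underlying idea can be illustrated independently of the library: draw samples from the fitted density and estimate its entropy as the average negative log density, H ≈ -(1/N) Σ log p(x_i). A self-contained NumPy sketch for a toy two-component 1-D mixture (not dpgmm's estimator):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_density(x, means, weights, var=1.0):
    """Log density of a 1-D Gaussian mixture with shared variance."""
    comps = (-0.5 * (x[:, None] - means) ** 2 / var
             - 0.5 * np.log(2 * np.pi * var))
    return np.log(np.exp(comps) @ weights)  # log sum_k w_k N(x | mu_k, var)

means = np.array([0.0, 6.0])
weights = np.array([0.5, 0.5])

# Sample from the mixture, then estimate H = -E[log p(x)] by Monte Carlo.
comp = rng.choice(2, size=100_000, p=weights)
samples = rng.normal(means[comp], 1.0)
entropy_mc = -log_density(samples, means, weights).mean()
```

For two well-separated unit-variance components, the estimate should be close to 0.5·log(2πe) + log 2 ≈ 2.112 nats.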
Data dimensions entanglement
Calculate the KL divergence between joint and product marginals to measure feature entanglement.
```python
from dpgmm.metrics import EntanglementFromTraceEstimator

estimator = EntanglementFromTraceEstimator(
    trace_path="/path/to/results/cgs_99.pkl",
    samples_num=100_000,
)
dkl_joint_prod = estimator.calculate_joint_and_prod_dkl()
dkl_symmetric = estimator.calculate_symmetric_dkl()

print(f"KL(Joint || Marginals Prod): {dkl_joint_prod:.4f}")
```
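For a single Gaussian, this divergence (the total correlation) has a closed form: D_KL = ½(Σ_d log Σ_dd − log det Σ), which is zero exactly when the covariance is diagonal. A small NumPy illustration, separate from the library's trace-based estimator:

```python
import numpy as np

def gaussian_total_correlation(cov):
    """KL(N(0, cov) || prod_d N(0, cov_dd)) = 0.5 * (sum log diag - logdet)."""
    sign, logdet = np.linalg.slogdet(cov)
    assert sign > 0, "covariance must be positive definite"
    return 0.5 * (np.sum(np.log(np.diag(cov))) - logdet)

# Correlated dimensions give a strictly positive divergence...
cov = np.array([[1.0, 0.8],
                [0.8, 1.0]])
tc = gaussian_total_correlation(cov)

# ...while independent (diagonal-covariance) dimensions give zero.
tc_indep = gaussian_total_correlation(np.eye(2))
```

For correlation 0.8 the value is −½ log(1 − 0.8²) ≈ 0.511 nats.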
Integrations & observability
The sampler supports W&B out of the box for tracking loss curves, cluster evolution, and system metrics. To enable experiment tracking, export the WANDB_API_KEY environment variable:
```shell
export WANDB_API_KEY=your_key_here
```
Benchmarks
Thanks to Triton kernels, dpgmm achieves significant speedups compared to standard implementations, especially in high-dimensional experiments. The following table showcases the average iteration time (in seconds) for $N=1000$ points using the full covariance model.
| Data dim | PyTorch CPU [s] | Optimized GPU [s] | Speedup |
|---|---|---|---|
| 128 | $2.893 \pm 0.702$ | $0.656 \pm 0.004$ | $\times 4.41$ |
| 256 | $6.300 \pm 1.237$ | $0.520 \pm 0.027$ | $\times 12.11$ |
| 512 | $23.857 \pm 1.843$ | $0.456 \pm 0.018$ | $\times 52.36$ |
| 1024 | $53.204 \pm 2.394$ | $1.196 \pm 0.024$ | $\times 44.47$ |
| 2048 | $140.795 \pm 4.535$ | $3.141 \pm 0.138$ | $\times 44.83$ |
| 4096 | $494.269 \pm 6.651$ | $13.414 \pm 0.611$ | $\times 36.85$ |
| 8192 | $3479.447 \pm 72.901$ | $37.880 \pm 1.458$ | $\times 91.85$ |
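Numbers like these can be reproduced with a simple wall-clock harness (on GPU, wrap the timed call with torch.cuda.synchronize() so queued kernels finish before the clock stops). A generic sketch with a placeholder workload standing in for one sampler iteration:

```python
import statistics
import time

def benchmark(fn, *args, warmup=2, repeats=5):
    """Mean and stdev of wall-clock time (seconds) over several repeats."""
    for _ in range(warmup):            # warm caches / JIT before timing
        fn(*args)
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return statistics.mean(times), statistics.stdev(times)

# Placeholder CPU workload standing in for one sampler iteration.
mean_s, std_s = benchmark(lambda n: sum(i * i for i in range(n)), 10_000)
```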
Development
This project uses uv for dependency management and Task (go-task) for orchestrating development workflows.
```shell
# Install dependencies
uv sync

# Install pre-commit hooks
uv run pre-commit install
```
Task Automation
A Taskfile.yml is provided to simplify common development commands:
```shell
# Run linter and formatter (Ruff)
uv run task lint

# Run security audits (Bandit & Safety)
uv run task audit

# Check code complexity (Xenon)
uv run task complexity

# Run all quality and security checks
uv run task check-all

# Build documentation
uv run task build-docs

# Run all tests
uv run task run-tests
```
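For reference, a Task configuration matching the commands above might look like the sketch below; the actual Taskfile.yml in the repository may define these tasks differently:

```yaml
version: "3"

tasks:
  lint:
    desc: Run Ruff linter and formatter
    cmds:
      - ruff check .
      - ruff format .
  audit:
    desc: Run security audits
    cmds:
      - bandit -r src
      - safety check
  run-tests:
    desc: Run the test suite
    cmds:
      - pytest
```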
File details
Details for the file dpgmm-0.1.0.tar.gz.
File metadata
- Download URL: dpgmm-0.1.0.tar.gz
- Upload date:
- Size: 46.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.10.3 (Ubuntu 24.04, CI)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f1c0695ffe32ebece1d43cfe81c4c981d822451118e40e003a51e862154d5bca |
| MD5 | 7a24f6b9e586df630deee430e4dc7878 |
| BLAKE2b-256 | b8cec4b7f6ad8924a09251c5f8cfb2d1c474e5db3541fcb3dc5ad39594888d2b |
File details
Details for the file dpgmm-0.1.0-py3-none-any.whl.
File metadata
- Download URL: dpgmm-0.1.0-py3-none-any.whl
- Upload date:
- Size: 78.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.10.3 (Ubuntu 24.04, CI)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c37f80c5a768158b9f4a1360f1c846f94edd029e6f8e72bd4b5d7e50a5b623cc |
| MD5 | 5e5bb1e04bd36cb9d68d4ae18199a39f |
| BLAKE2b-256 | 928d2315efd4c505dcddf05d213608acff3602733fff9094fdc83c8b2328c69f |