DIMA: Diffusion–Intrinsic Manifold Autoencoder: DMAP, GPLM, DDPM

DIMA — Diffusion–Intrinsic Manifold Autoencoder

DIMA combines three components into a practical, scalable manifold autoencoder:

  • DMAP: Diffusion Maps encoder (builds a kNN graph in ambient space and embeds points into diffusion coordinates).
  • GPLM: Nyström / inducing-point kernel ridge decoder (maps latent diffusion coordinates back to ambient space).
  • DDPM: Latent diffusion model (optional) that learns a generative prior over normalized latents and can “refine” latents.

Philosophy: keep DMAP + GPLM fast on CPU/RAM (NumPy/SciPy, sparse ops), and optionally run DDPM on JAX (CPU or GPU).


Installation

CPU-only (recommended to start)

pip install dima

Optional extras

  • Faster kNN search (FAISS):
pip install dima[faiss]
  • Hugging Face upload/load:
pip install dima[hf]

GPU note (JAX): pip install dima installs CPU jaxlib by default. For GPU, install the correct JAX wheel for your CUDA/ROCm setup first (per JAX docs), then install dima.


Quickstart

import numpy as np
import jax
from dima import DIMA

R_iX = np.random.randn(5000, 16).astype(np.float32)

dima = DIMA(R_iX)                 # trains DMAP, GPLM, DDPM with defaults

Z = dima.encode(R_iX[:10])        # ambient -> normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)  # latent -> ambient (np), no DDPM
X_ref = dima.decode(Z, refine=True, t_start=10)  # DDPM refinement

X_gen = dima.sample(1000)         # unconditional samples -> ambient (np)

What DIMA trains

1) DMAP encoder (Diffusion Maps)

DMAP builds a kNN graph over ambient data $R_{iX}\in\mathbb{R}^{N\times D}$ and returns diffusion coordinates $R_{ix}\in\mathbb{R}^{N\times d}$.

Key idea (no theory): DMAP produces a geometry-aware latent space where nearby points on the manifold remain nearby in diffusion distance.

Main knobs (DMAP):

  • d (int): latent dimension.
  • k (int): neighbors in the kNN graph. If None, a heuristic is used.
  • beta / β (float): kernel sharpness for the RBF affinity $K_{ij}=\exp\left(-\beta\,\|x_i-x_j\|^2/\varepsilon\right)$.
  • eps / ε (float or None): kernel bandwidth. If None, estimated from kNN distances (median heuristic).
  • alpha / α (float): density normalization exponent. Common values: 0.0 (none) or 1.0 (often robust).
  • t (float): diffusion time exponent (scales eigenvalues as $\lambda^t$). Often 0.5 or 1.0.
  • drop_trivial (bool): drops the top eigenvector/eigenvalue (the constant mode).
  • sym (str): symmetrization mode for sparse kNN kernel graph: "max" or "mean".
  • ann_backend (str): "auto" | "faiss" | "pynndescent" | "sklearn" | "brute".
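
To make the knobs above concrete, here is a minimal dense sketch of a diffusion-map embedding in NumPy. It is not DIMA's implementation: it uses a full $N\times N$ kernel instead of a sparse kNN graph, the function name `diffusion_map_dense` is hypothetical, and the fallback bandwidth (median of all pairwise squared distances) is an assumed stand-in for the kNN-based heuristic. It only shows where beta, eps, alpha, t and drop_trivial enter the computation.

```python
import numpy as np

def diffusion_map_dense(X, d=2, beta=1.0, eps=None, alpha=1.0, t=0.5,
                        drop_trivial=True):
    """Dense diffusion-map embedding for small N (illustrative only)."""
    # Pairwise squared Euclidean distances, clipped at 0 against round-off.
    sq = np.sum(X * X, axis=1)
    D2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * (X @ X.T), 0.0)
    if eps is None:
        # Assumption: median of nonzero squared distances as the bandwidth.
        eps = np.median(D2[D2 > 0])
    K = np.exp(-beta * D2 / eps)

    # Density normalization: K_ij <- K_ij / (q_i^alpha * q_j^alpha).
    q = K.sum(axis=1)
    K = K / np.outer(q ** alpha, q ** alpha)

    # S = D^{-1/2} K D^{-1/2} is symmetric and shares its eigenvalues
    # with the Markov matrix P = D^{-1} K.
    deg = K.sum(axis=1)
    S = K / np.sqrt(np.outer(deg, deg))
    w, V = np.linalg.eigh(S)          # eigenvalues in ascending order
    w, V = w[::-1], V[:, ::-1]        # reorder to descending

    psi = V / np.sqrt(deg)[:, None]   # right eigenvectors of P
    start = 1 if drop_trivial else 0  # skip the constant mode (eigenvalue 1)
    lam = np.clip(w[start:start + d], 0.0, None)
    return psi[:, start:start + d] * lam ** t, lam
```

With drop_trivial=False the first returned eigenvalue is 1 and the first coordinate is constant across points, which is why that mode is usually dropped.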

Typical DMAP presets

  • Fast-ish and stable: k=128..512, alpha=1.0, t=0.5, drop_trivial=True.
  • If your data is very noisy, try larger k and/or larger eps_mul.
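
The interaction between k, the bandwidth, and sym can be sketched with a brute-force kNN kernel graph. This is an illustration under assumptions, not the package's code: the function name is hypothetical, eps_mul is assumed to be a plain multiplier on the estimated bandwidth (the text above does not define it), and the graph is stored densely for readability where a real implementation would use sparse matrices and an ANN backend.

```python
import numpy as np

def knn_kernel_graph(X, k=10, beta=1.0, eps=None, eps_mul=1.0, sym="max"):
    """Brute-force kNN RBF graph with 'max' or 'mean' symmetrization."""
    n = X.shape[0]
    sq = np.sum(X * X, axis=1)
    D2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * (X @ X.T), 0.0)
    np.fill_diagonal(D2, np.inf)                     # exclude self-matches

    nn = np.argpartition(D2, k - 1, axis=1)[:, :k]   # (n, k) neighbor indices
    nn_d2 = np.take_along_axis(D2, nn, axis=1)       # (n, k) squared distances

    if eps is None:
        # Assumed heuristic: median kNN squared distance, scaled by eps_mul.
        eps = eps_mul * np.median(nn_d2)

    # Directed kNN affinities: row i holds weights to i's k neighbors.
    W = np.zeros((n, n))
    np.put_along_axis(W, nn, np.exp(-beta * nn_d2 / eps), axis=1)

    if sym == "max":
        W = np.maximum(W, W.T)     # keep an edge if either endpoint chose it
    elif sym == "mean":
        W = 0.5 * (W + W.T)        # one-sided edges get half weight
    else:
        raise ValueError(f"unknown sym mode: {sym!r}")
    return W, eps
```

In this sketch a larger k adds edges per row and a larger eps_mul widens the kernel, both of which average over more neighbors and so smooth out noise.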

2) GPLM decoder (Nyström kernel ridge / inducing GP)

GPLM learns a mapping from latents back to ambient:

  • Inputs: latent training points $R_{ix}\in\mathbb{R}^{N\times d}$
  • Targets: ambient training points $R_{iX}\in\mathbb{R}^{N\times D}$

Instead of a full $N\times N$ kernel solve, GPLM uses inducing points $Z_{mx}$ with $m\ll N$ and solves a reduced system:

  • Build affinities $C_{im}=\exp\left(-\beta\,\|R_{ix}-Z_{mx}\|^2/\varepsilon\right)$
  • Solve a stabilized kernel ridge / GP mean system to obtain weights $M_{mX}$
  • Predict: $\hat R_{aX}=C_{am}M_{mX}+\mu_X$
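
The three steps above can be sketched in a few lines of NumPy. The exact system DIMA solves is not spelled out here, so this sketch assumes the common subset-of-regressors form $(C^\top C + \sigma^2 K_{mm} + \text{jitter}\cdot I)\,M = C^\top (R_{iX}-\mu_X)$; the function names are hypothetical.

```python
import numpy as np

def rbf_affinity(A, B, eps, beta=1.0):
    """exp(-beta * ||a - b||^2 / eps) for every row pair of A and B."""
    d2 = (np.sum(A * A, axis=1)[:, None] + np.sum(B * B, axis=1)[None, :]
          - 2.0 * (A @ B.T))
    return np.exp(-beta * np.maximum(d2, 0.0) / eps)

def fit_inducing_ridge(R_ix, R_iX, Z_mx, eps, sigma2=1e-5, jitter=1e-8):
    """Solve the reduced m x m system for the decoder weights M_mX."""
    mu_X = R_iX.mean(axis=0)                       # ambient mean
    C_im = rbf_affinity(R_ix, Z_mx, eps)           # (N, m)
    K_mm = rbf_affinity(Z_mx, Z_mx, eps)           # (m, m)
    # Assumed system matrix: C^T C + sigma2 * K_mm + jitter * I.
    A = C_im.T @ C_im + sigma2 * K_mm + jitter * np.eye(Z_mx.shape[0])
    L = np.linalg.cholesky(A)                      # A = L L^T
    rhs = C_im.T @ (R_iX - mu_X)
    M_mX = np.linalg.solve(L.T, np.linalg.solve(L, rhs))
    return M_mX, mu_X

def predict_inducing(R_ax, Z_mx, M_mX, mu_X, eps):
    """Predict ambient points: C_am @ M_mX + mu_X."""
    return rbf_affinity(R_ax, Z_mx, eps) @ M_mX + mu_X
```

Only an $m\times m$ matrix is ever factorized, which is what keeps the decoder cheap when $m\ll N$.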

Main knobs (GPLM):

  • m (int): number of inducing points. Bigger → better accuracy, more compute.

  • inducing (str): inducing strategy:

    • "kmeans_medoids" (default): kmeans centers snapped to nearest training latent.
    • "fps": farthest-point sampling (space-filling).
    • "random_subset": fastest.
    • "given": use provided Z_mx.
  • sigma2 / σ² (float): ridge regularization. Too small can overfit or cause numerical instability; too large blurs reconstructions.

  • jitter (float): tiny diagonal stabilizer for Cholesky.

  • eps / ε (float or None): RBF bandwidth in latent space. If None, estimated from latent kNN distances.

  • k_eps / κ_eps (int): neighbors used for the $\varepsilon$ heuristic.

  • pred_k / pred_κ (int or None): prediction-time inducing neighbors:

    • None means use all inducing points (best accuracy).
    • a small number (e.g. 128 or 256) speeds inference (slightly lower accuracy).
  • whiten_latent (bool): optionally standardize latent dimensions before kernel computation.

  • center_X (bool): subtract and re-add ambient mean (usually helpful).

  • fit_block (int): block size for streaming $C^T C$ accumulation (memory/perf knob).

  • ann_backend (str): same options as DMAP.
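
As an illustration of the "fps" inducing strategy, the following is a standard greedy farthest-point sampling routine in NumPy. It is a generic sketch rather than DIMA's code; the function name and the random choice of the first point are assumptions.

```python
import numpy as np

def farthest_point_sampling(R_ix, m, seed=0):
    """Return indices of m rows of R_ix chosen by greedy farthest-point sampling."""
    rng = np.random.default_rng(seed)
    n = R_ix.shape[0]
    idx = np.empty(m, dtype=np.int64)
    idx[0] = rng.integers(n)                       # assumed: random first point
    # min_d2[i] = squared distance from point i to its nearest selected point.
    min_d2 = np.sum((R_ix - R_ix[idx[0]]) ** 2, axis=1)
    for j in range(1, m):
        idx[j] = int(np.argmax(min_d2))            # farthest from current set
        d2 = np.sum((R_ix - R_ix[idx[j]]) ** 2, axis=1)
        min_d2 = np.minimum(min_d2, d2)
    return idx
```

The selected latents `R_ix[idx]` could then serve as the inducing points $Z_{mx}$. Each iteration costs $O(N\cdot d)$, so the total is $O(N\cdot m\cdot d)$.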

Typical GPLM presets

  • Accurate: m=1024..4096, pred_k=None, sigma2=1e-5 (tune).
  • Faster inference: set pred_k=128..512.
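
The effect of pred_k can be sketched as a prediction that sums over only the nearest inducing points of each query. This is an assumed reading of the option, with a hypothetical function name; it finds neighbors by brute force, whereas a real implementation would use an ANN index to avoid computing all $B\cdot m$ distances.

```python
import numpy as np

def predict_pred_k(R_ax, Z_mx, M_mX, mu_X, eps, beta=1.0, pred_k=None):
    """Decode latents using all inducing points or only the pred_k nearest."""
    sq_a = np.sum(R_ax * R_ax, axis=1)
    sq_z = np.sum(Z_mx * Z_mx, axis=1)
    d2 = np.maximum(sq_a[:, None] + sq_z[None, :] - 2.0 * (R_ax @ Z_mx.T), 0.0)
    m = Z_mx.shape[0]
    if pred_k is None or pred_k >= m:
        # Full prediction: every inducing point contributes.
        return np.exp(-beta * d2 / eps) @ M_mX + mu_X
    # Indices of the pred_k nearest inducing points per query, shape (B, k).
    nn = np.argpartition(d2, pred_k - 1, axis=1)[:, :pred_k]
    C_ak = np.exp(-beta * np.take_along_axis(d2, nn, axis=1) / eps)
    # Weighted sum over the gathered weight rows M_mX[nn], shape (B, k, D).
    return np.einsum("ak,akX->aX", C_ak, M_mX[nn]) + mu_X
```

Because RBF affinities decay quickly with distance, the dropped terms are small when eps is small relative to the spacing of the inducing points, which is why truncation costs little accuracy in that regime.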

3) DDPM latent diffusion (optional generative prior)

DDPM learns a distribution over normalized latents, $Z = \frac{R_x - \mu}{\sigma}$, and can:

  • sample new latents,
  • refine a given latent by projecting it onto the learned latent manifold/prior.

In DIMA, DDPM operates purely in latent space (dimension d), so it’s lightweight compared to image DDPMs.
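
The latent normalization is simple enough to show directly. The sketch below assumes per-dimension statistics and a small floor on $\sigma$ to avoid division by zero; the README does not say whether DIMA uses per-dimension or global statistics, and the helper names are hypothetical.

```python
import numpy as np

def fit_latent_stats(R_ix, floor=1e-8):
    """Per-dimension mean and standard deviation of the training latents."""
    mu = R_ix.mean(axis=0)
    sigma = np.maximum(R_ix.std(axis=0), floor)  # floor guards constant dims
    return mu, sigma

def normalize_latent(R_x, mu, sigma):
    """Map raw diffusion coordinates to normalized latents Z."""
    return (R_x - mu) / sigma

def denormalize_latent(Z, mu, sigma):
    """Invert the normalization before handing latents to the decoder."""
    return Z * sigma + mu
```

Diffusion coordinates can differ in scale by orders of magnitude across dimensions, so bringing them to unit scale gives the noise model a well-conditioned target.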

Main knobs (DDPM):

  • T (int): number of diffusion steps. Common: 100..1000. DIMA default: 200.
  • hidden_dim (int): MLP width.
  • t_embed_dim (int): time embedding dimension.
  • n_iter (int): training iterations (more is better for sampling quality).
  • batch_size (int): training batch size.
  • learning_rate (float): Adam learning rate.
  • ema_decay (float): EMA smoothing for stable sampling (typical: 0.999).
  • beta_max (float): caps noise schedule (stability knob).
  • eps (float): numerical stabilizer.
  • verbose_every (int): progress prints.
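
For orientation, the role of T and beta_max can be illustrated with the standard DDPM forward process. The linear schedule and the beta_min parameter below are assumptions (the README only says that beta_max caps the schedule), and NumPy is used instead of JAX to keep the sketch dependency-light.

```python
import numpy as np

def linear_beta_schedule(T=200, beta_min=1e-4, beta_max=0.02):
    """Assumed linear schedule: betas rise from beta_min to beta_max over T steps."""
    betas = np.linspace(beta_min, beta_max, T)
    alphas_bar = np.cumprod(1.0 - betas)   # cumulative signal retention
    return betas, alphas_bar

def forward_noise(z0, t, alphas_bar, rng):
    """Sample z_t ~ q(z_t | z_0) for a 1-based step t; also return the noise."""
    a = alphas_bar[t - 1]
    noise = rng.standard_normal(z0.shape)
    return np.sqrt(a) * z0 + np.sqrt(1.0 - a) * noise, noise
```

During training, a noise-prediction network would receive `z_t` and `t` and be regressed onto `noise`. A lower beta_max keeps more signal at step T.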

Refinement knobs (during decoding):

  • refine (bool): enable/disable DDPM refinement.

  • t_start (int): how strongly to “project” using reverse diffusion.

    • small (1..10) = gentle projection
    • larger (20..100) = stronger projection (can oversmooth or drift if DDPM undertrained)
  • add_noise (bool): whether to forward-noise before reverse steps.
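
One plausible reading of these knobs is the usual "noise, then denoise" projection: optionally forward-noise the latent to step t_start, then run t_start ancestral reverse steps. The sketch below follows that reading with standard DDPM update rules; it is not taken from DIMA's source, and `refine_latent` and `gaussian_prior_eps` are hypothetical names. The toy noise predictor is exact only for a standard-normal latent prior and stands in for the trained network.

```python
import numpy as np

def refine_latent(z, eps_model, betas, t_start=10, add_noise=True, seed=0):
    """Project z toward the prior by running t_start reverse DDPM steps."""
    rng = np.random.default_rng(seed)
    alphas = 1.0 - betas
    alphas_bar = np.cumprod(alphas)
    z = np.array(z, dtype=np.float64)
    if t_start <= 0:
        return z                                   # nothing to do
    if add_noise:
        # Forward-noise to step t_start before denoising.
        a = alphas_bar[t_start - 1]
        z = np.sqrt(a) * z + np.sqrt(1.0 - a) * rng.standard_normal(z.shape)
    for t in range(t_start, 0, -1):
        eps_hat = eps_model(z, t)
        coef = betas[t - 1] / np.sqrt(1.0 - alphas_bar[t - 1])
        mean = (z - coef * eps_hat) / np.sqrt(alphas[t - 1])
        # No noise is injected on the final step.
        noise = rng.standard_normal(z.shape) if t > 1 else 0.0
        z = mean + np.sqrt(betas[t - 1]) * noise
    return z

def gaussian_prior_eps(alphas_bar):
    """Toy predictor: for z_0 ~ N(0, I), E[eps | z_t] = sqrt(1 - abar_t) * z_t."""
    def eps_model(z, t):
        return np.sqrt(1.0 - alphas_bar[t - 1]) * z
    return eps_model
```

Under this reading, a larger t_start removes more of the input's own detail before the model rebuilds it, which matches the "stronger projection" behaviour described above and explains why an undertrained model can drift.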


API

Core class

from dima import DIMA
dima = DIMA(R_iX)

Encode / decode

Z = dima.encode(X)                      # (B,d) normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)    # (B,D) reconstruction (np)
X_ref = dima.decode(Z, refine=True, t_start=10)  # refined decode

Polymorphic call

Z = dima(X)     # if X.shape[-1] == D
X = dima(Z)     # if Z.shape[-1] == d
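
The shape-based dispatch can be pictured with a small wrapper. This is a hypothetical sketch, not DIMA's class: `ShapeDispatcher` and its error handling are assumptions, including the choice to reject D == d, where the last axis alone cannot tell an ambient array from a latent one.

```python
import numpy as np

class ShapeDispatcher:
    """Route a call to encode or decode based on the size of the last axis."""

    def __init__(self, encode, decode, D, d):
        if D == d:
            # Assumption of this sketch: refuse the ambiguous case outright.
            raise ValueError("D == d makes shape-based dispatch ambiguous")
        self.encode, self.decode, self.D, self.d = encode, decode, D, d

    def __call__(self, A):
        last = np.shape(A)[-1]
        if last == self.D:
            return self.encode(A)      # ambient -> latent
        if last == self.d:
            return self.decode(A)      # latent -> ambient
        raise ValueError(f"last axis {last} matches neither D={self.D} nor d={self.d}")
```

When the ambient and latent dimensions could coincide, calling encode and decode explicitly avoids relying on the dispatch.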

Sampling

X_gen = dima.sample(1000)            # ambient samples (np)
Z_gen = dima.sample(1000, decode=False)  # latent samples (jnp)

Configuration patterns

1) Use defaults (cleanest)

dima = DIMA(R_iX)

2) Pass only a few knobs

dima = DIMA(
    R_iX,
    d=64,
    dmap_kwargs=dict(k=256, alpha=1.0, t=0.5),
    gplm_kwargs=dict(m=2048, sigma2=1e-5, pred_k=256),
    ddpm_kwargs=dict(n_iter=50_000, hidden_dim=256),
)

Saving / Loading

Local

dima.save_local("dima.msgpack", "config.json")
dima2 = DIMA.load_local("dima.msgpack", ddpm_device="auto")

Hugging Face (optional)

dima.upload_to_huggingface(repo_id="username/dima-model", hf_token="...")

dima3 = DIMA.load_from_huggingface("username/dima-model", ddpm_device="auto")

Performance notes

  • DMAP: main costs are kNN search + sparse eigensolve.

    • kNN is faster with FAISS.
    • Increasing k increases graph density and compute.
  • GPLM: training cost roughly scales with $N\cdot m$ (streamed by fit_block).

    • Inference cost is $B\cdot m$ if pred_k=None, or $B\cdot \text{pred\_k}$ if using inducing kNN.
  • DDPM: scales with latent dimension d, steps T, and training iterations.

    • GPU helps but CPU works for smaller d and fewer iterations.
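
The streamed GPLM fit mentioned above can be sketched as a blocked accumulation of $C^\top C$ and $C^\top R_{iX}$, so that the full $N\times m$ affinity matrix never has to be held in memory. Function names are hypothetical, and the targets are assumed to be already centered if center_X is in use.

```python
import numpy as np

def rbf_affinity(A, B, eps, beta=1.0):
    """exp(-beta * ||a - b||^2 / eps) for every row pair of A and B."""
    d2 = (np.sum(A * A, axis=1)[:, None] + np.sum(B * B, axis=1)[None, :]
          - 2.0 * (A @ B.T))
    return np.exp(-beta * np.maximum(d2, 0.0) / eps)

def accumulate_normal_equations(R_ix, R_iX, Z_mx, eps, beta=1.0, fit_block=1024):
    """Accumulate C^T C (m x m) and C^T X (m x D) one block of rows at a time."""
    m, D = Z_mx.shape[0], R_iX.shape[1]
    CtC = np.zeros((m, m))
    CtX = np.zeros((m, D))
    for s in range(0, R_ix.shape[0], fit_block):
        C_blk = rbf_affinity(R_ix[s:s + fit_block], Z_mx, eps, beta)  # (b, m)
        CtC += C_blk.T @ C_blk
        CtX += C_blk.T @ R_iX[s:s + fit_block]
    return CtC, CtX
```

Peak memory is then about fit_block $\cdot\, m$ floats for the block plus $m^2$ for the accumulator, while total work still scales with $N\cdot m^2$ for the products.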

Troubleshooting

“Unknown backend: gpu”

Your JAX installation has no GPU backend (CPU only). Use:

dima = DIMA(R_iX, ddpm_device="cpu")

or install a GPU-enabled JAX build.

Reconstructions are blurry / low quality

  • Increase gplm_kwargs["m"]
  • Decrease gplm_kwargs["sigma2"] slightly (careful: too small can destabilize)
  • Set pred_k=None (use all inducing points)

DDPM refinement makes reconstructions worse

  • Reduce t_start (try 3..10)
  • Increase DDPM training (ddpm_kwargs["n_iter"])
  • Disable add_noise for gentler behavior

Citations

  • Diffusion Maps / DMAE inspiration: Diffusion Map AutoEncoder (DMAE).
  • Latent diffusion prior: Denoising Diffusion Probabilistic Models (DDPM).

License

MIT
