
DIMA: Diffusion–Isocoder Manifold-Autoencoder


DIMA combines three components into a practical, scalable manifold autoencoder:

  • DMAP: Diffusion Maps encoder (builds a kNN graph in ambient space and embeds points into diffusion coordinates).
  • GPLM: Nyström / inducing-point kernel ridge decoder (maps latent diffusion coordinates back to ambient space).
  • DDPM: Latent diffusion model (optional) that learns a generative prior over normalized latents and can “refine” latents.

Philosophy: keep DMAP + GPLM fast on CPU/RAM (NumPy/SciPy, sparse ops), and optionally run DDPM on JAX (CPU or GPU).


Installation

CPU-only (recommended to start)

pip install dima

Optional extras

  • Faster kNN search (FAISS):
pip install "dima[faiss]"
  • Hugging Face upload/load:
pip install "dima[hf]"

GPU note (JAX): pip install dima installs CPU jaxlib by default. For GPU, install the correct JAX wheel for your CUDA/ROCm setup first (per JAX docs), then install dima.


Quickstart

import numpy as np
from dima import DIMA

# Toy data: 5,000 points in 16 dimensions (replace with your own N x D array)
R_iX = np.random.randn(5000, 16).astype(np.float32)

dima = DIMA(R_iX)                 # trains DMAP, GPLM, DDPM with defaults

Z = dima.encode(R_iX[:10])        # ambient -> normalized latent (jax.numpy array)
X_hat = dima.decode(Z, refine=False)  # latent -> ambient (np), no DDPM
X_ref = dima.decode(Z, refine=True, t_start=10)  # DDPM refinement

X_gen = dima.sample(1000)         # unconditional samples -> ambient (np)

What DIMA trains

1) DMAP encoder (Diffusion Maps)

DMAP builds a kNN graph over ambient data $R_{iX}\in\mathbb{R}^{N\times D}$ and returns diffusion coordinates $R_{ix}\in\mathbb{R}^{N\times d}$.

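Key idea (no theory): DMAP produces a geometry-aware latent space in which points that are close along the manifold stay close, because Euclidean distance between diffusion coordinates approximates the diffusion distance on the kNN graph.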

Main knobs (DMAP):

  • d (int): latent dimension.
  • k (int): neighbors in the kNN graph. If None, a heuristic is used.
  • beta / β (float): kernel sharpness for the RBF affinity $K_{ij}=\exp\left(-\beta\,\|x_i-x_j\|^2/\varepsilon\right)$.
  • eps / ε (float or None): kernel bandwidth. If None, estimated from kNN distances (median heuristic).
  • alpha / α (float): density normalization exponent. Common values: 0.0 (no density correction) or 1.0 (removes the influence of non-uniform sampling density; often robust).
  • t (float): diffusion time; each diffusion coordinate is an eigenvector weighted by $\lambda^t$. Often 0.5 or 1.0.
  • drop_trivial (bool): drops the top eigenvector/eigenvalue (the constant mode).
  • sym (str): symmetrization mode for sparse kNN kernel graph: "max" or "mean".
  • ann_backend (str): "auto" | "faiss" | "pynndescent" | "sklearn" | "brute".
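The DMAP pipeline and knobs above can be sketched in plain NumPy/SciPy. This is a minimal illustration under stated assumptions, not DIMA's implementation: the function name `diffusion_map_sketch` is hypothetical, `scipy.spatial.cKDTree` stands in for the configurable `ann_backend`, and the default bandwidth (median of squared kNN distances) is one assumed reading of the median heuristic.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
from scipy.spatial import cKDTree


def diffusion_map_sketch(X, d=2, k=16, beta=1.0, eps=None, alpha=1.0, t=1.0,
                         drop_trivial=True, sym="max"):
    """Illustrative Diffusion Maps on a sparse kNN graph (hypothetical, not DIMA's code)."""
    N = X.shape[0]
    # kNN search; cKDTree stands in for the FAISS / pynndescent / sklearn backends
    dist, idx = cKDTree(X).query(X, k=k + 1)
    dist, idx = dist[:, 1:], idx[:, 1:]              # drop each point's self-match
    if eps is None:
        eps = np.median(dist ** 2)                   # median heuristic (an assumption)
    rows = np.repeat(np.arange(N), k)
    vals = np.exp(-beta * dist.ravel() ** 2 / eps)   # K_ij = exp(-beta ||x_i - x_j||^2 / eps)
    K = sp.csr_matrix((vals, (rows, idx.ravel())), shape=(N, N))
    # Symmetrize the directed kNN graph ("max" or "mean")
    K = K.maximum(K.T) if sym == "max" else 0.5 * (K + K.T)
    # Density normalization: K_ij <- K_ij / (q_i q_j)^alpha
    q = np.asarray(K.sum(axis=1)).ravel()
    Dq = sp.diags(q ** -alpha)
    K = Dq @ K @ Dq
    # Symmetric conjugate S = D^-1/2 K D^-1/2 has the same eigenvalues as P = D^-1 K
    inv_sqrt_deg = 1.0 / np.sqrt(np.asarray(K.sum(axis=1)).ravel())
    S = sp.diags(inv_sqrt_deg) @ K @ sp.diags(inv_sqrt_deg)
    n_eig = d + 1 if drop_trivial else d
    lam, V = eigsh(S, k=n_eig, which="LA")
    order = np.argsort(lam)[::-1]                    # largest eigenvalue first
    lam, V = lam[order], V[:, order]
    psi = V * inv_sqrt_deg[:, None]                  # right eigenvectors of P
    if drop_trivial:
        lam, psi = lam[1:], psi[:, 1:]               # drop the constant mode (lambda = 1)
    # Diffusion coordinates: eigenvectors weighted by lambda^t (abs guards tiny negatives)
    return psi * np.abs(lam) ** t, lam
```

Working with the symmetric matrix $S$ lets a sparse symmetric eigensolver be used; the right eigenvectors of the Markov matrix $P$ are then recovered by rescaling with $D^{-1/2}$.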

Typical DMAP presets

  • Fast-ish and stable: k=128..512, alpha=1.0, t=0.5, drop_trivial=True.
  • If your data is very noisy, try larger k and/or larger eps_mul.

2) GPLM decoder (Nyström kernel ridge / inducing GP)

GPLM learns a mapping from latents back to ambient:

  • Inputs: latent training points $R_{ix}\in\mathbb{R}^{N\times d}$
  • Targets: ambient training points $R_{iX}\in\mathbb{R}^{N\times D}$

Instead of a full $N\times N$ kernel solve, GPLM uses inducing points $Z_{mx}$ with $m\ll N$ and solves a reduced system:

  • Build affinities $C_{im}=\exp\left(-\beta\,\|R_{ix}-Z_{mx}\|^2/\varepsilon\right)$
  • Solve a stabilized kernel ridge / GP mean system to obtain weights $M_{mX}$
  • Predict: $\hat R_{aX}=C_{am}M_{mX}+\mu_X$
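The three steps above can be sketched with NumPy/SciPy. This is an assumption-laden illustration, not GPLM's code: the names `rbf`, `fit_nystrom_decoder` and `predict` are hypothetical, and the "stabilized" system is taken to be the standard subset-of-regressors mean $M_{mX}=(C^\top C+\sigma^2K_{mm}+\text{jitter}\,I)^{-1}C^\top(R_{iX}-\mu_X)$, which GPLM may or may not use exactly. $C^\top C$ is accumulated in row blocks to mirror the fit_block knob.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.spatial.distance import cdist


def rbf(A, B, beta, eps):
    # C_im = exp(-beta * ||a_i - b_m||^2 / eps)
    return np.exp(-beta * cdist(A, B, "sqeuclidean") / eps)


def fit_nystrom_decoder(R_ix, R_iX, Z_mx, beta=1.0, eps=1.0, sigma2=1e-5,
                        jitter=1e-8, fit_block=1024):
    """Hypothetical subset-of-regressors kernel ridge fit with streamed C^T C."""
    mu_X = R_iX.mean(axis=0)
    Y = R_iX - mu_X                                    # center the ambient targets
    m = Z_mx.shape[0]
    CtC = np.zeros((m, m))
    CtY = np.zeros((m, Y.shape[1]))
    for s in range(0, R_ix.shape[0], fit_block):       # stream over training rows
        C = rbf(R_ix[s:s + fit_block], Z_mx, beta, eps)
        CtC += C.T @ C
        CtY += C.T @ Y[s:s + fit_block]
    K_mm = rbf(Z_mx, Z_mx, beta, eps)
    A = CtC + sigma2 * K_mm + jitter * np.eye(m)       # ridge term + diagonal jitter
    M_mX = cho_solve(cho_factor(A), CtY)               # Cholesky solve for the weights
    return M_mX, mu_X


def predict(R_ax, Z_mx, M_mX, mu_X, beta=1.0, eps=1.0):
    # \hat R_aX = C_am M_mX + mu_X
    return rbf(R_ax, Z_mx, beta, eps) @ M_mX + mu_X
```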

Main knobs (GPLM):

  • m (int): number of inducing points. Bigger → better accuracy, more compute.

  • inducing (str): inducing strategy:

    • "kmeans_medoids" (default): kmeans centers snapped to nearest training latent.
    • "fps": farthest-point sampling (space-filling).
    • "random_subset": fastest.
    • "given": use provided Z_mx.
  • sigma2 / σ2 (float): ridge regularization. Too small can overfit / cause instability; too large blurs reconstructions.

  • jitter (float): tiny diagonal stabilizer for Cholesky.

  • eps / ε (float or None): RBF bandwidth in latent space. If None, estimated from latent kNN distances.

  • k_eps / κ_eps (int): neighbors used for the $\varepsilon$ heuristic.

  • pred_k / pred_κ (int or None): prediction-time inducing neighbors:

    • None means use all inducing points (best accuracy).
    • a small number (e.g. 128 or 256) speeds inference (slightly lower accuracy).
  • whiten_latent (bool): optionally standardize latent dimensions before kernel computation.

  • center_X (bool): subtract and re-add ambient mean (usually helpful).

  • fit_block (int): block size for streaming $C^T C$ accumulation (memory/perf knob).

  • ann_backend (str): same options as DMAP.
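Two of these knobs, the "fps" inducing strategy and pred_k, can be illustrated as follows. Both functions are hypothetical sketches: farthest-point sampling is the standard greedy algorithm, while the pred_k behaviour (keeping only each query's pred_k largest affinities, i.e. its nearest inducing points, and zeroing the rest) is an assumed reading of "prediction-time inducing neighbors", not a confirmed description of GPLM.

```python
import numpy as np
from scipy.spatial.distance import cdist


def farthest_point_sampling(R_ix, m, seed=0):
    """Greedy FPS: each new index is the latent farthest from all points chosen so far."""
    rng = np.random.default_rng(seed)
    chosen = [int(rng.integers(len(R_ix)))]
    min_d2 = np.sum((R_ix - R_ix[chosen[0]]) ** 2, axis=1)
    for _ in range(m - 1):
        nxt = int(np.argmax(min_d2))
        chosen.append(nxt)
        # distance from every point to its nearest chosen point
        min_d2 = np.minimum(min_d2, np.sum((R_ix - R_ix[nxt]) ** 2, axis=1))
    return np.array(chosen)


def predict_pred_k(R_ax, Z_mx, M_mX, mu_X, beta, eps, pred_k):
    """Decode using only each query's pred_k nearest inducing points (hypothetical)."""
    C = np.exp(-beta * cdist(R_ax, Z_mx, "sqeuclidean") / eps)
    if pred_k is not None and pred_k < Z_mx.shape[0]:
        # indices of everything except the pred_k largest affinities per row
        drop = np.argpartition(C, -pred_k, axis=1)[:, :-pred_k]
        np.put_along_axis(C, drop, 0.0, axis=1)
    return C @ M_mX + mu_X
```

A production version would look up the pred_k neighbors with an ANN index instead of computing all $m$ affinities first; the dense version is kept here for clarity.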

Typical GPLM presets

  • Accurate: m=1024..4096, pred_k=None, sigma2=1e-5 (tune).
  • Faster inference: set pred_k=128..512.

3) DDPM latent diffusion (optional generative prior)

DDPM learns a distribution over normalized latents $Z=(R_{ix}-\mu)/\sigma$ and can:

  • sample new latents,
  • refine a given latent by projecting it onto the learned latent manifold/prior.

In DIMA, DDPM operates purely in latent space (dimension d), so it’s lightweight compared to image DDPMs.

Main knobs (DDPM):

  • T (int): number of diffusion steps. Common: 100..1000. DIMA default: 200.
  • hidden_dim (int): MLP width.
  • t_embed_dim (int): time embedding dimension.
  • n_iter (int): training iterations (more iterations generally improve sample quality).
  • batch_size (int): training batch size.
  • learning_rate (float): Adam learning rate.
  • ema_decay (float): EMA smoothing for stable sampling (typical: 0.999).
  • beta_max (float): caps noise schedule (stability knob).
  • eps (float): numerical stabilizer.
  • verbose_every (int): progress prints.
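The normalization and forward noising that the DDPM is trained on can be sketched in NumPy. Assumptions, all hypothetical: per-dimension $\mu$ and $\sigma$, the linear schedule from the original DDPM paper with `beta_max` as its upper end (DIMA's actual schedule is not documented here), and the function names themselves.

```python
import numpy as np


def normalize_latents(R_ix, eps=1e-8):
    """Z = (R_ix - mu) / sigma with per-dimension statistics (an assumption)."""
    mu = R_ix.mean(axis=0)
    sigma = R_ix.std(axis=0) + eps               # small eps guards against zero variance
    return (R_ix - mu) / sigma, mu, sigma


def linear_beta_schedule(T=200, beta_min=1e-4, beta_max=0.02):
    """Linear variance schedule; beta_max caps the per-step noise variance."""
    betas = np.linspace(beta_min, beta_max, T)
    alpha_bar = np.cumprod(1.0 - betas)          # abar_t = prod_{s<=t} (1 - beta_s)
    return betas, alpha_bar


def q_sample(z0, t, alpha_bar, rng):
    """Forward process z_t = sqrt(abar_t) z0 + sqrt(1 - abar_t) * noise.

    t holds one integer step per example; the network is trained to
    predict `noise` from (z_t, t) with a mean-squared-error loss.
    """
    noise = rng.standard_normal(z0.shape)
    a = alpha_bar[t][:, None]                    # shape (B, 1) for broadcasting
    return np.sqrt(a) * z0 + np.sqrt(1.0 - a) * noise, noise
```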

Refinement knobs (during decoding):

  • refine (bool): enable/disable DDPM refinement.

  • t_start (int): the diffusion step from which reverse diffusion starts; larger values “project” the latent more strongly.

    • small (1..10) = gentle projection
    • larger (20..100) = stronger projection (can oversmooth or drift if DDPM undertrained)
  • add_noise (bool): whether to forward-noise before reverse steps.
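Refinement can be pictured as an SDEdit-style loop: optionally noise the latent to step t_start, then run standard DDPM ancestral reverse steps back to step 0. This is a sketch under assumptions, not DIMA's decode path: `refine_latent` and the noise predictor `eps_model(z, t)` are hypothetical, the reverse variance is taken as $\beta_t$, and add_noise=False is assumed to mean the clean latent is treated directly as $z_{t_\text{start}}$.

```python
import numpy as np


def refine_latent(z, eps_model, betas, t_start=10, add_noise=True, seed=0):
    """Optionally noise z to step t_start, then run t_start ancestral DDPM reverse steps."""
    rng = np.random.default_rng(seed)
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    t0 = t_start - 1                                   # 0-indexed starting step
    if add_noise:
        z = (np.sqrt(alpha_bar[t0]) * z
             + np.sqrt(1.0 - alpha_bar[t0]) * rng.standard_normal(z.shape))
    for t in range(t0, -1, -1):
        eps_hat = eps_model(z, t)                      # predicted noise at step t
        mean = (z - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_hat) / np.sqrt(alphas[t])
        noise = rng.standard_normal(z.shape) if t > 0 else 0.0
        z = mean + np.sqrt(betas[t]) * noise           # sigma_t^2 = beta_t variant
    return z
```

Small t_start leaves most of the original latent intact, while large t_start lets the learned prior dominate, which is why an undertrained DDPM can drift at high values.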


API

Core class

from dima import DIMA
dima = DIMA(R_iX)

Encode / decode

Z = dima.encode(X)                      # (B,d) normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)    # (B,D) reconstruction (np)
X_ref = dima.decode(Z, refine=True, t_start=10)  # refined decode

Polymorphic call

Z = dima(X)     # if X.shape[-1] == D
X = dima(Z)     # if Z.shape[-1] == d
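The polymorphic call routes on the size of the last axis. A minimal stand-in for that dispatch rule, assuming the hypothetical class name `ShapeDispatch`; it rejects D == d because the rule would then be ambiguous (how DIMA itself handles that case is not documented here).

```python
import numpy as np


class ShapeDispatch:
    """Hypothetical stand-in showing how a call could route by the last axis."""

    def __init__(self, D, d, encode, decode):
        if D == d:
            raise ValueError("ambiguous dispatch: ambient and latent dims are equal")
        self.D, self.d, self.encode, self.decode = D, d, encode, decode

    def __call__(self, x):
        n = np.shape(x)[-1]
        if n == self.D:
            return self.encode(x)        # ambient input -> latent
        if n == self.d:
            return self.decode(x)        # latent input -> ambient
        raise ValueError(f"last dimension {n} matches neither D={self.D} nor d={self.d}")
```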

Sampling

X_gen = dima.sample(1000)            # ambient samples (np)
Z_gen = dima.sample(1000, decode=False)  # latent samples (jnp)

Configuration patterns

1) Use defaults (cleanest)

dima = DIMA(R_iX)

2) Pass only a few knobs

dima = DIMA(
    R_iX,
    d=64,
    dmap_kwargs=dict(k=256, alpha=1.0, t=0.5),
    gplm_kwargs=dict(m=2048, sigma2=1e-5, pred_k=256),
    ddpm_kwargs=dict(n_iter=50_000, hidden_dim=256),
)

Saving / Loading

Local

dima.save_local("dima.msgpack", "config.json")
dima2 = DIMA.load_local("dima.msgpack", ddpm_device="auto")

Hugging Face (optional)

dima.upload_to_huggingface(repo_id="username/dima-model", hf_token="...")

dima3 = DIMA.load_from_huggingface("username/dima-model", ddpm_device="auto")

Performance notes

  • DMAP: main costs are kNN search + sparse eigensolve.

    • kNN is faster with FAISS.
    • Increasing k increases graph density and compute.
  • GPLM: training cost roughly scales with $N\cdot m^2$ (accumulating $C^T C$, streamed in blocks of fit_block rows), plus a one-off $O(m^3)$ Cholesky solve.

    • Inference cost is $B\cdot m$ if pred_k=None, or $B\cdot\text{pred\_k}$ if using inducing kNN.
  • DDPM: scales with latent dimension d, steps T, and training iterations.

    • GPU helps but CPU works for smaller d and fewer iterations.
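To get a feel for these scalings, a back-of-the-envelope helper can count multiply-adds. The function `gplm_cost_proxies` and its fit_block default of 4096 are hypothetical; the counts are order-of-magnitude proxies, not measured timings.

```python
def gplm_cost_proxies(N, m, B, D, pred_k=None, fit_block=4096):
    """Order-of-magnitude operation counts for a Nystrom-style decoder (rough proxies)."""
    k = m if pred_k is None else pred_k
    return {
        "fit_CtC": N * m * m,              # accumulate C^T C over all N training rows
        "fit_cholesky": m ** 3 // 3,       # factor the m x m system once
        "infer": B * k * D,                # C_am M_mX with k active inducing points
        "block_bytes": fit_block * m * 8,  # one float64 block of C held in memory
    }
```

For example, with m = 2048, B = 1000 and D = 16, pred_k = 256 cuts the inference product by a factor of 8, and one 4096-row float64 block of $C$ occupies 64 MiB.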

Troubleshooting

“Unknown backend: gpu”

Your JAX installation only provides the CPU backend. Use:

dima = DIMA(R_iX, ddpm_device="cpu")

or install a GPU-enabled JAX build.
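To choose the device programmatically, JAX's `jax.default_backend()` reports the active platform ("cpu", "gpu" or "tpu"). The helper below is a hypothetical sketch that maps this onto the documented ddpm_device values "cpu" and "auto", assuming "auto" picks up an available GPU.

```python
def suggest_ddpm_device() -> str:
    """Return "auto" when JAX's default backend is a GPU, otherwise "cpu" (hypothetical helper)."""
    try:
        import jax
    except ImportError:
        return "cpu"                      # no JAX at all: stay on CPU
    return "auto" if jax.default_backend() == "gpu" else "cpu"
```

Usage: `dima = DIMA(R_iX, ddpm_device=suggest_ddpm_device())`.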

Reconstructions are blurry / low quality

  • Increase gplm_kwargs["m"]
  • Decrease gplm_kwargs["sigma2"] slightly (careful: too small can destabilize)
  • Set pred_k=None (use all inducing points)
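When trying these adjustments, it helps to compare settings with a single number. A hypothetical helper that measures reconstruction error relative to the data variance (0 means perfect, 1 means no better than predicting the mean):

```python
import numpy as np


def relative_recon_error(X, X_hat):
    """||X - X_hat||_F^2 / ||X - mean(X)||_F^2 (hypothetical helper)."""
    X = np.asarray(X, dtype=np.float64)
    X_hat = np.asarray(X_hat, dtype=np.float64)
    num = np.sum((X - X_hat) ** 2)
    den = np.sum((X - X.mean(axis=0)) ** 2)
    return float(num / den)
```

For example, `relative_recon_error(X, dima.decode(dima.encode(X), refine=False))` on a held-out batch gives a value to track while changing m, sigma2 or pred_k.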

DDPM refinement makes reconstructions worse

  • Reduce t_start (try 3..10)
  • Increase DDPM training (ddpm_kwargs["n_iter"])
  • Disable add_noise for gentler behavior

Citations

  • Diffusion Maps / DMAE inspiration: Diffusion Map AutoEncoder (DMAE).
  • Latent diffusion prior: Denoising Diffusion Probabilistic Models (DDPM).

License

MIT



Download files


Source Distribution

dima-0.1.6.tar.gz (32.3 kB)

Uploaded Source

Built Distribution


dima-0.1.6-py3-none-any.whl (31.8 kB)

Uploaded Python 3

File details

Details for the file dima-0.1.6.tar.gz.

File metadata

  • Download URL: dima-0.1.6.tar.gz
  • Upload date:
  • Size: 32.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for dima-0.1.6.tar.gz
Algorithm Hash digest
SHA256 a850fe8bdd287ab65125da4be05a61ad9ead28d4159b35f16e608551227ff79a
MD5 3923cda60abd843ff948de0f2266b477
BLAKE2b-256 5640cead2098f89d7fbd500372135cce08d260a3ab05ee87a58e9a820107a15f


File details

Details for the file dima-0.1.6-py3-none-any.whl.

File metadata

  • Download URL: dima-0.1.6-py3-none-any.whl
  • Upload date:
  • Size: 31.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for dima-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 04d90a3744f65508a5c155315a2e20b434d34266468f02fc67133d1c3e558bf4
MD5 65bd904f867f5a833282502e22b40a97
BLAKE2b-256 cc8a9e579a6be8fe5308a89a05ee4826924d10d905447dcde3d8118e7d042a10

