DIMA: Diffusion–Isocoder Manifold-Autoencoder

DIMA combines three components into a practical, scalable manifold autoencoder:

  • DMAP: Diffusion Maps encoder (builds a kNN graph in ambient space and embeds points into diffusion coordinates).
  • GPLM: Nyström / inducing-point kernel ridge decoder (maps latent diffusion coordinates back to ambient space).
  • DDPM: Latent diffusion model (optional) that learns a generative prior over normalized latents and can “refine” latents.

Philosophy: keep DMAP + GPLM fast on CPU/RAM (NumPy/SciPy, sparse ops), and optionally run DDPM on JAX (CPU or GPU).


Installation

CPU-only (recommended to start)

pip install dima

Optional extras

  • Faster kNN search (FAISS):
pip install dima[faiss]
  • Hugging Face upload/load:
pip install dima[hf]

GPU note (JAX): pip install dima installs CPU jaxlib by default. For GPU, install the correct JAX wheel for your CUDA/ROCm setup first (per JAX docs), then install dima.


Quickstart

import numpy as np
import jax
from dima import DIMA

R_iX = np.random.randn(5000, 16).astype(np.float32)

dima = DIMA(R_iX)                 # trains DMAP, GPLM, DDPM with defaults

Z = dima.encode(R_iX[:10])        # ambient -> normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)  # latent -> ambient (np), no DDPM
X_ref = dima.decode(Z, refine=True, t_start=10)  # DDPM refinement

X_gen = dima.sample(1000)         # unconditional samples -> ambient (np)

What DIMA trains

1) DMAP encoder (Diffusion Maps)

DMAP builds a kNN graph over ambient data $R_{iX}\in\mathbb{R}^{N\times D}$ and returns diffusion coordinates $R_{ix}\in\mathbb{R}^{N\times d}$.

Key idea (no theory): DMAP produces a geometry-aware latent space where nearby points on the manifold remain nearby in diffusion distance.

Main knobs (DMAP):

  • d (int): latent dimension.
  • k (int): neighbors in the kNN graph. If None, a heuristic is used.
  • beta / β (float): kernel sharpness for the RBF affinity $K_{ij}=\exp\left(-\beta \|x_i-x_j\|^2/\varepsilon\right)$.
  • eps / ε (float or None): kernel bandwidth. If None, estimated from kNN distances (median heuristic).
  • alpha / α (float): density normalization exponent. Common values: 0.0 (none) or 1.0 (often robust).
  • t (float): diffusion time exponent (scales eigenvalues as $\lambda^t$). Often 0.5 or 1.0.
  • drop_trivial (bool): drops the top eigenvector/eigenvalue (the constant mode).
  • sym (str): symmetrization mode for sparse kNN kernel graph: "max" or "mean".
  • ann_backend (str): "auto" | "faiss" | "pynndescent" | "sklearn" | "brute".
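
To make the knobs above concrete, here is a minimal dense sketch of a diffusion map in plain NumPy. It is not DIMA's implementation: the function name `diffusion_map_sketch`, the brute-force neighbor search, the dense eigensolve, and the exact order of the normalization steps are assumptions chosen for readability, whereas the library itself is described as using sparse ops and an ANN backend.

```python
import numpy as np

def diffusion_map_sketch(X, d=2, k=10, beta=1.0, eps=None, alpha=1.0,
                         t=0.5, drop_trivial=True, sym="max"):
    """Dense, brute-force diffusion map (illustrative only, small N)."""
    N = X.shape[0]
    # Pairwise squared distances in ambient space.
    sq = np.sum(X**2, axis=1)
    D2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    np.fill_diagonal(D2, np.inf)  # exclude self-matches from the kNN search
    # Indices of the k nearest neighbors of every point.
    idx = np.argpartition(D2, k, axis=1)[:, :k]
    rows = np.repeat(np.arange(N), k)
    d2_knn = D2[rows, idx.ravel()]
    if eps is None:
        # Assumed form of the median heuristic: median squared kNN distance.
        eps = np.median(d2_knn)
    # RBF affinities restricted to the kNN graph.
    K = np.zeros((N, N))
    K[rows, idx.ravel()] = np.exp(-beta * d2_knn / eps)
    # Symmetrize the directed kNN graph ("max" or "mean").
    K = np.maximum(K, K.T) if sym == "max" else 0.5 * (K + K.T)
    # Density normalization with exponent alpha.
    q = K.sum(axis=1)
    K = K / np.outer(q**alpha, q**alpha)
    # Symmetric conjugate S of the Markov matrix P = diag(deg)^-1 K.
    deg = K.sum(axis=1)
    S = K / np.sqrt(np.outer(deg, deg))
    lam, U = np.linalg.eigh(S)
    lam, U = lam[::-1], U[:, ::-1]      # sort eigenpairs in descending order
    psi = U / np.sqrt(deg)[:, None]     # right eigenvectors of P
    start = 1 if drop_trivial else 0    # skip the constant mode if requested
    lam_d = np.clip(lam[start:start + d], 0.0, None)  # guard against tiny negatives
    return psi[:, start:start + d] * lam_d**t

# Toy usage: 200 points on a circle embedded in 2-D.
theta = np.linspace(0.0, 2.0 * np.pi, 200, endpoint=False)
X_circle = np.stack([np.cos(theta), np.sin(theta)], axis=1)
Y_circle = diffusion_map_sketch(X_circle, d=2, k=10)
```

With `drop_trivial=False` the first returned coordinate is constant across points (for a connected graph), which is why dropping it loses no information.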

Typical DMAP presets

  • Fast-ish and stable: k=128..512, alpha=1.0, t=0.5, drop_trivial=True.
  • If your data is very noisy, try larger k and/or larger eps_mul.

2) GPLM decoder (Nyström kernel ridge / inducing GP)

GPLM learns a mapping from latents back to ambient:

  • Inputs: latent training points $R_{ix}\in\mathbb{R}^{N\times d}$
  • Targets: ambient training points $R_{iX}\in\mathbb{R}^{N\times D}$

Instead of a full $N\times N$ kernel solve, GPLM uses inducing points $Z_{mx}$ with $m\ll N$ and solves a reduced system:

  • Build affinities $C_{im}=\exp\left(-\beta\|R_{ix}-Z_{mx}\|^2/\varepsilon\right)$
  • Solve a stabilized kernel ridge / GP mean system to obtain weights $M_{mX}$
  • Predict: $\hat R_{aX}=C_{am}M_{mX}+\mu_X$
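
The three steps above can be sketched in NumPy as follows. The README does not spell out the exact "stabilized" system, so this sketch assumes the common subset-of-regressors form $(C^\top C + \sigma^2 K_{mm} + \text{jitter}\cdot I)\,M = C^\top (R_{iX}-\mu_X)$; the function names `rbf`, `fit_inducing_krr`, and `predict` are hypothetical and not part of the dima API.

```python
import numpy as np

def rbf(A, B, beta=1.0, eps=1.0):
    """RBF affinities exp(-beta * ||a - b||^2 / eps) between rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2.0 * A @ B.T
    return np.exp(-beta * np.maximum(d2, 0.0) / eps)

def fit_inducing_krr(R_ix, R_iX, Z_mx, beta=1.0, eps=1.0, sigma2=1e-5, jitter=1e-8):
    """Solve the reduced m x m system for the weights M_mX (assumed form)."""
    mu_X = R_iX.mean(axis=0)                      # ambient mean, re-added at predict time
    C_im = rbf(R_ix, Z_mx, beta, eps)             # (N, m) affinities
    K_mm = rbf(Z_mx, Z_mx, beta, eps)             # (m, m) inducing kernel
    A = C_im.T @ C_im + sigma2 * K_mm + jitter * np.eye(Z_mx.shape[0])
    L = np.linalg.cholesky(A)                     # jitter keeps A positive definite
    rhs = C_im.T @ (R_iX - mu_X)
    M_mX = np.linalg.solve(L.T, np.linalg.solve(L, rhs))
    return M_mX, mu_X

def predict(R_ax, Z_mx, M_mX, mu_X, beta=1.0, eps=1.0):
    """Decode latents: R_hat = C_am @ M_mX + mu_X."""
    return rbf(R_ax, Z_mx, beta, eps) @ M_mX + mu_X

# Toy usage: 1-D latent angle decoded to a 2-D circle.
rng = np.random.default_rng(0)
z_train = rng.uniform(0.0, 2.0 * np.pi, size=(400, 1))
X_train = np.hstack([np.sin(z_train), np.cos(z_train)])
Z_ind = np.linspace(0.0, 2.0 * np.pi, 25)[:, None]   # 25 inducing points on a grid
M_fit, mu_fit = fit_inducing_krr(z_train, X_train, Z_ind, eps=0.2)
z_new = np.linspace(0.5, 5.5, 50)[:, None]
X_new = predict(z_new, Z_ind, M_fit, mu_fit, eps=0.2)
```

Only an $m\times m$ system is factorized, so the solve cost does not depend on $N$ beyond forming $C^\top C$.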

Main knobs (GPLM):

  • m (int): number of inducing points. Bigger → better accuracy, more compute.

  • inducing (str): inducing strategy:

    • "kmeans_medoids" (default): kmeans centers snapped to nearest training latent.
    • "fps": farthest-point sampling (space-filling).
    • "random_subset": fastest.
    • "given": use provided Z_mx.
  • sigma2 / σ² (float): ridge regularization. Too small can overfit / cause instability; too large blurs reconstructions.

  • jitter (float): tiny diagonal stabilizer for Cholesky.

  • eps / ε (float or None): RBF bandwidth in latent space. If None, estimated from latent kNN distances.

  • k_eps / κ_eps (int): neighbors used for the $\varepsilon$ heuristic.

  • pred_k / pred_κ (int or None): prediction-time inducing neighbors:

    • None means use all inducing points (best accuracy).
    • a small number (e.g. 128 or 256) speeds inference (slightly lower accuracy).
  • whiten_latent (bool): optionally standardize latent dimensions before kernel computation.

  • center_X (bool): subtract and re-add ambient mean (usually helpful).

  • fit_block (int): block size for streaming $C^T C$ accumulation (memory/perf knob).

  • ann_backend (str): same options as DMAP.
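
As an illustration of the "fps" inducing strategy listed above, here is a minimal farthest-point sampling sketch. The function name and the random choice of the first point are assumptions; DIMA's own selection code may differ.

```python
import numpy as np

def farthest_point_sampling(R_ix, m, seed=0):
    """Pick m row indices of R_ix that greedily maximize the minimum distance."""
    rng = np.random.default_rng(seed)
    N = R_ix.shape[0]
    chosen = np.empty(m, dtype=np.int64)
    chosen[0] = rng.integers(N)                   # arbitrary starting point
    # Squared distance from every point to its nearest chosen point so far.
    min_d2 = np.sum((R_ix - R_ix[chosen[0]])**2, axis=1)
    for j in range(1, m):
        chosen[j] = np.argmax(min_d2)             # farthest from the current set
        d2 = np.sum((R_ix - R_ix[chosen[j]])**2, axis=1)
        min_d2 = np.minimum(min_d2, d2)
    return chosen

# Toy usage: 32 space-filling inducing points from 1000 random 3-D latents.
latents = np.random.default_rng(1).standard_normal((1000, 3))
fps_idx = farthest_point_sampling(latents, m=32)
Z_fps = latents[fps_idx]
```

Each iteration costs one pass over the data, so the whole selection is $O(N\cdot m)$ distance evaluations.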

Typical GPLM presets

  • Accurate: m=1024..4096, pred_k=None, sigma2=1e-5 (tune).
  • Faster inference: set pred_k=128..512.
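
One plausible reading of pred_k is that each query only sums over its pred_k nearest inducing points, which is a good approximation when far-away RBF affinities are close to zero. The sketch below implements that reading; `predict_topk` and its signature are hypothetical, and the README does not confirm that DIMA truncates in exactly this way.

```python
import numpy as np

def predict_topk(R_ax, Z_mx, M_mX, mu_X, pred_k, beta=1.0, eps=1.0):
    """Decode using only the pred_k nearest inducing points per query."""
    d2 = (np.sum(R_ax**2, 1)[:, None] + np.sum(Z_mx**2, 1)[None, :]
          - 2.0 * R_ax @ Z_mx.T)
    d2 = np.maximum(d2, 0.0)                                   # (B, m)
    # Indices of the pred_k closest inducing points for each query row.
    idx = np.argpartition(d2, pred_k - 1, axis=1)[:, :pred_k]  # (B, pred_k)
    C = np.exp(-beta * np.take_along_axis(d2, idx, axis=1) / eps)
    # Weighted sum over the selected rows of M_mX, then re-add the mean.
    return np.einsum("bk,bkd->bd", C, M_mX[idx]) + mu_X

# Toy usage with random weights (shapes only; not a trained decoder).
rng = np.random.default_rng(0)
Z_demo = rng.standard_normal((30, 2))     # m = 30 inducing points, d = 2
M_demo = rng.standard_normal((30, 3))     # weights mapping to D = 3
mu_demo = rng.standard_normal(3)
Q_demo = rng.standard_normal((5, 2))      # B = 5 query latents
X_top8 = predict_topk(Q_demo, Z_demo, M_demo, mu_demo, pred_k=8)
X_all = predict_topk(Q_demo, Z_demo, M_demo, mu_demo, pred_k=30)
```

Setting `pred_k` equal to `m` reproduces the full prediction, which gives a simple way to measure how much accuracy a smaller value costs on your data.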

3) DDPM latent diffusion (optional generative prior)

DDPM learns a distribution over normalized latents, $Z = \frac{R_x - \mu}{\sigma}$, and can:

  • sample new latents,
  • refine a given latent by projecting it onto the learned latent manifold/prior.

In DIMA, DDPM operates purely in latent space (dimension d), so it’s lightweight compared to image DDPMs.
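
For orientation, the normalization and the standard DDPM forward (noising) process can be sketched as below. The linear schedule, the `beta_min` value, the per-dimension statistics, and all function names are assumptions for illustration; the README only states that latents are normalized and that beta_max caps the schedule.

```python
import numpy as np

def normalize_latents(R_x, eps=1e-6):
    """Z = (R_x - mu) / sigma, with per-dimension statistics (assumed)."""
    mu = R_x.mean(axis=0)
    sigma = R_x.std(axis=0) + eps        # eps guards against zero variance
    return (R_x - mu) / sigma, mu, sigma

def make_schedule(T=200, beta_min=1e-4, beta_max=0.02):
    """Linear noise schedule and its cumulative products alpha_bar_t."""
    betas = np.linspace(beta_min, beta_max, T)
    alpha_bar = np.cumprod(1.0 - betas)
    return betas, alpha_bar

def q_sample(Z0, t, alpha_bar, rng):
    """Forward noising: z_t = sqrt(alpha_bar_t) z_0 + sqrt(1 - alpha_bar_t) noise."""
    noise = rng.standard_normal(Z0.shape)
    a = alpha_bar[t]                     # t is a 0-based step index here
    return np.sqrt(a) * Z0 + np.sqrt(1.0 - a) * noise, noise

# Toy usage on fake diffusion coordinates with d = 4.
rng = np.random.default_rng(0)
R_x_demo = rng.normal(loc=3.0, scale=2.0, size=(1000, 4))
Z_norm, mu_lat, sigma_lat = normalize_latents(R_x_demo)
betas_demo, alpha_bar_demo = make_schedule(T=200)
Z_noisy, noise_used = q_sample(Z_norm, t=50, alpha_bar=alpha_bar_demo, rng=rng)
```

A denoising network would be trained to predict `noise_used` from `Z_noisy` and the step index; because the inputs are only d-dimensional vectors, a small MLP is enough.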

Main knobs (DDPM):

  • T (int): number of diffusion steps. Common: 100..1000. DIMA default: 200.
  • hidden_dim (int): MLP width.
  • t_embed_dim (int): time embedding dimension.
  • n_iter (int): training iterations (more iterations generally improve sample quality, at higher training cost).
  • batch_size (int): training batch size.
  • learning_rate (float): Adam learning rate.
  • ema_decay (float): EMA smoothing for stable sampling (typical: 0.999).
  • beta_max (float): caps noise schedule (stability knob).
  • eps (float): numerical stabilizer.
  • verbose_every (int): progress prints.
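
The ema_decay knob refers to an exponential moving average of the network weights, which is what is typically used at sampling time. A minimal sketch of the update rule, with parameters held in a plain dict of NumPy arrays (the container and the function name are assumptions, not DIMA's internals):

```python
import numpy as np

def ema_update(ema_params, params, decay=0.999):
    """Return decay * ema + (1 - decay) * params for every parameter array."""
    return {name: decay * ema_params[name] + (1.0 - decay) * params[name]
            for name in params}

# Toy usage: the average moves slowly toward the current weights.
params_demo = {"w": np.ones((2, 2)), "b": np.ones(2)}
ema_demo = {"w": np.zeros((2, 2)), "b": np.zeros(2)}
ema_demo = ema_update(ema_demo, params_demo, decay=0.999)
```

With decay 0.999 the average has an effective memory of roughly 1000 steps, so very short training runs leave it dominated by its initial value.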

Refinement knobs (during decoding):

  • refine (bool): enable/disable DDPM refinement.

  • t_start (int): how strongly to “project” using reverse diffusion.

    • small (1..10) = gentle projection
    • larger (20..100) = stronger projection (can oversmooth or drift if DDPM undertrained)
  • add_noise (bool): whether to forward-noise before reverse steps.
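
The interaction of t_start and add_noise can be illustrated with a generic DDPM reverse loop. This is a sketch under assumptions, not DIMA's code: the function `refine_sketch`, the linear schedule, the choice of reverse variance equal to beta_t, and the stand-in noise predictor (the exact one for a standard normal prior) are all illustrative.

```python
import numpy as np

def refine_sketch(Z, t_start, betas, alpha_bar, eps_model, add_noise=True, seed=0):
    """Optionally noise Z to step t_start, then run t_start reverse DDPM steps."""
    rng = np.random.default_rng(seed)
    z = np.array(Z, dtype=np.float64)
    if add_noise and t_start > 0:
        a = alpha_bar[t_start - 1]
        z = np.sqrt(a) * z + np.sqrt(1.0 - a) * rng.standard_normal(z.shape)
    for t in range(t_start - 1, -1, -1):
        eps_hat = eps_model(z, t)                      # predicted noise
        # Posterior mean of the standard DDPM reverse step.
        mean = (z - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_hat) \
            / np.sqrt(1.0 - betas[t])
        noise = rng.standard_normal(z.shape) if t > 0 else 0.0
        z = mean + np.sqrt(betas[t]) * noise           # no noise on the last step
    return z

# Assumed linear schedule with T = 200 steps.
betas_ref = np.linspace(1e-4, 0.02, 200)
alpha_bar_ref = np.cumprod(1.0 - betas_ref)

def gaussian_prior_eps(z, t):
    """Stand-in for a trained network: exact noise predictor for an N(0, I) prior."""
    return np.sqrt(1.0 - alpha_bar_ref[t]) * z

Z_in = np.random.default_rng(1).standard_normal((8, 4))
Z_gentle = refine_sketch(Z_in, 5, betas_ref, alpha_bar_ref, gaussian_prior_eps,
                         add_noise=False)
Z_strong = refine_sketch(Z_in, 100, betas_ref, alpha_bar_ref, gaussian_prior_eps)
```

A small t_start only applies the last few low-noise reverse steps, so the output stays close to the input; a large t_start discards more of the input and relies more heavily on the learned prior.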


API

Core class

from dima import DIMA
dima = DIMA(R_iX)

Encode / decode

Z = dima.encode(X)                      # (B,d) normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)    # (B,D) reconstruction (np)
X_ref = dima.decode(Z, refine=True, t_start=10)  # refined decode

Polymorphic call

Z = dima(X)     # if X.shape[-1] == D
X = dima(Z)     # if Z.shape[-1] == d
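
The dispatch rule above depends only on the size of the last axis. A toy sketch of that pattern follows; the class `ShapeDispatcher` and its linear encode/decode are hypothetical stand-ins, and how DIMA itself resolves the ambiguous case D == d is not stated in this README (the sketch simply prefers encode).

```python
import numpy as np

class ShapeDispatcher:
    """Toy model whose __call__ routes to encode or decode by last-axis size."""

    def __init__(self, D, d, seed=0):
        self.D, self.d = D, d
        # Toy linear map standing in for a trained encoder/decoder pair.
        self.W = np.random.default_rng(seed).standard_normal((D, d))

    def encode(self, X):
        return X @ self.W            # (B, D) -> (B, d)

    def decode(self, Z):
        return Z @ self.W.T          # (B, d) -> (B, D)

    def __call__(self, A):
        A = np.asarray(A)
        if A.shape[-1] == self.D:    # checked first, so D == d means encode
            return self.encode(A)
        if A.shape[-1] == self.d:
            return self.decode(A)
        raise ValueError(f"last dimension must be {self.D} or {self.d}")

toy = ShapeDispatcher(D=16, d=4)
Z_toy = toy(np.zeros((3, 16)))       # routed to encode
X_toy = toy(np.zeros((3, 4)))        # routed to decode
```

When D and d could coincide, calling `encode` or `decode` explicitly avoids relying on the dispatch order.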

Sampling

X_gen = dima.sample(1000)            # ambient samples (np)
Z_gen = dima.sample(1000, decode=False)  # latent samples (jnp)

Configuration patterns

1) Use defaults (cleanest)

dima = DIMA(R_iX)

2) Pass only a few knobs

dima = DIMA(
    R_iX,
    d=64,
    dmap_kwargs=dict(k=256, alpha=1.0, t=0.5),
    gplm_kwargs=dict(m=2048, sigma2=1e-5, pred_k=256),
    ddpm_kwargs=dict(n_iter=50_000, hidden_dim=256),
)

Saving / Loading

Local

dima.save_local("dima.msgpack", "config.json")
dima2 = DIMA.load_local("dima.msgpack", ddpm_device="auto")

Hugging Face (optional)

dima.upload_to_huggingface(repo_id="username/dima-model", hf_token="...")

dima3 = DIMA.load_from_huggingface("username/dima-model", ddpm_device="auto")

Performance notes

  • DMAP: main costs are kNN search + sparse eigensolve.

    • kNN is faster with FAISS.
    • Increasing k increases graph density and compute.
  • GPLM: training cost roughly scales with $N\cdot m$ (streamed by fit_block).

    • Inference cost is $B\cdot m$ if pred_k=None, or $B\cdot \text{pred\_k}$ if using inducing kNN.
  • DDPM: scales with latent dimension d, steps T, and training iterations.

    • GPU helps but CPU works for smaller d and fewer iterations.
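
The role of fit_block in GPLM training can be sketched as a blocked accumulation of $C^\top C$ and $C^\top Y$, so that the full $N\times m$ affinity matrix is never held in memory at once. The function below is an illustrative assumption about how such streaming might look, not DIMA's actual routine.

```python
import numpy as np

def accumulate_normal_equations(R_ix, Y_iX, Z_mx, fit_block=1024, beta=1.0, eps=1.0):
    """Accumulate C^T C (m x m) and C^T Y (m x D) over row blocks of the data."""
    N, m = R_ix.shape[0], Z_mx.shape[0]
    CtC = np.zeros((m, m))
    CtY = np.zeros((m, Y_iX.shape[1]))
    z2 = np.sum(Z_mx**2, axis=1)
    for start in range(0, N, fit_block):
        R_blk = R_ix[start:start + fit_block]
        # Affinities between this block and the inducing points only.
        d2 = np.sum(R_blk**2, 1)[:, None] + z2[None, :] - 2.0 * R_blk @ Z_mx.T
        C_blk = np.exp(-beta * np.maximum(d2, 0.0) / eps)     # (block, m)
        CtC += C_blk.T @ C_blk
        CtY += C_blk.T @ Y_iX[start:start + fit_block]
    return CtC, CtY

# Toy usage: results do not depend on the block size.
rng = np.random.default_rng(0)
R_demo = rng.standard_normal((500, 3))
Y_demo = rng.standard_normal((500, 6))
Z_demo_blk = R_demo[:20]
CtC_small, CtY_small = accumulate_normal_equations(R_demo, Y_demo, Z_demo_blk, fit_block=64)
CtC_big, CtY_big = accumulate_normal_equations(R_demo, Y_demo, Z_demo_blk, fit_block=500)
```

Peak memory is governed by `fit_block * m` rather than `N * m`, while total arithmetic stays proportional to $N\cdot m^2$ for the $C^\top C$ products.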

Troubleshooting

“Unknown backend: gpu”

Your JAX installation has no GPU backend (CPU-only jaxlib). Use:

dima = DIMA(R_iX, ddpm_device="cpu")

or install a GPU-enabled JAX build.
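
To avoid hard-coding the device, you can ask JAX which backend it initialized and fall back to CPU otherwise. The helper below is a hypothetical convenience, not part of dima; it relies only on `jax.default_backend()`, and passing `"gpu"` as a ddpm_device value is an assumption inferred from the error message above.

```python
def pick_ddpm_device():
    """Return "gpu" if JAX's default backend is a GPU, otherwise "cpu"."""
    try:
        import jax
    except ImportError:
        return "cpu"                      # no JAX at all: nothing to accelerate
    # default_backend() reports the platform JAX actually initialized.
    return "gpu" if jax.default_backend() == "gpu" else "cpu"

device = pick_ddpm_device()
# Hypothetical usage: DIMA(R_iX, ddpm_device=device)
```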

Reconstructions are blurry / low quality

  • Increase gplm_kwargs["m"]
  • Decrease gplm_kwargs["sigma2"] slightly (careful: too small can destabilize)
  • Set pred_k=None (use all inducing points)

DDPM refinement makes reconstructions worse

  • Reduce t_start (try 3..10)
  • Increase DDPM training (ddpm_kwargs["n_iter"])
  • Disable add_noise for gentler behavior

Citations

  • Diffusion Maps / DMAE inspiration: Diffusion Map AutoEncoder (DMAE).
  • Latent diffusion prior: Denoising Diffusion Probabilistic Models (DDPM).

License

MIT
