DIMA: Diffusion–Intrinsic Manifold Autoencoder
DIMA combines three components into a practical, scalable manifold autoencoder:
- DMAP: Diffusion Maps encoder (builds a kNN graph in ambient space and embeds points into diffusion coordinates).
- GPLM: Nyström / inducing-point kernel ridge decoder (maps latent diffusion coordinates back to ambient space).
- DDPM: Latent diffusion model (optional) that learns a generative prior over normalized latents and can “refine” latents.
Philosophy: keep DMAP + GPLM fast on CPU/RAM (NumPy/SciPy, sparse ops), and optionally run DDPM on JAX (CPU or GPU).
Installation
CPU-only (recommended to start)
```shell
pip install dima
```
Optional extras
- Faster kNN search (FAISS):
```shell
pip install "dima[faiss]"
```
- Hugging Face upload/load:
```shell
pip install "dima[hf]"
```
GPU note (JAX):
`pip install dima` installs the CPU `jaxlib` by default. For GPU support, first install the JAX wheel that matches your CUDA/ROCm setup (following the JAX installation docs), then install `dima`.
Quickstart
```python
import numpy as np
from dima import DIMA

R_iX = np.random.randn(5000, 16).astype(np.float32)  # toy ambient data: N=5000, D=16

dima = DIMA(R_iX)  # trains DMAP, GPLM, DDPM with defaults

Z = dima.encode(R_iX[:10])                        # ambient -> normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)              # latent -> ambient (np), no DDPM
X_ref = dima.decode(Z, refine=True, t_start=10)   # with DDPM refinement
X_gen = dima.sample(1000)                         # unconditional samples -> ambient (np)
```
What DIMA trains
1) DMAP encoder (Diffusion Maps)
DMAP builds a kNN graph over ambient data $R_{iX}\in\mathbb{R}^{N\times D}$ and returns diffusion coordinates $R_{ix}\in\mathbb{R}^{N\times d}$.
Key idea (informally): DMAP produces a geometry-aware latent space in which points that are close along the manifold stay close in diffusion distance.
Main knobs (DMAP):
- `d` (int): latent dimension.
- `k` (int): number of neighbors in the kNN graph. If `None`, a heuristic is used.
- `beta`/`β` (float): kernel sharpness for the RBF affinity $K_{ij}=\exp\left(-\beta\,\|x_i-x_j\|^2/\varepsilon\right)$.
- `eps`/`ε` (float or None): kernel bandwidth. If `None`, it is estimated from kNN distances (median heuristic).
- `alpha`/`α` (float): density-normalization exponent. Common values: `0.0` (no normalization) or `1.0` (often robust).
- `t` (float): diffusion-time exponent (eigenvalues are scaled as $\lambda^t$). Often `0.5` or `1.0`.
- `drop_trivial` (bool): drop the top eigenvector/eigenvalue (the constant mode).
- `sym` (str): symmetrization mode for the sparse kNN kernel graph: `"max"` or `"mean"`.
- `ann_backend` (str): `"auto" | "faiss" | "pynndescent" | "sklearn" | "brute"`.
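To show how these knobs interact, here is a dense NumPy sketch of a Diffusion Maps embedding. It follows standard Diffusion Maps conventions and is not DIMA's implementation. It computes dense O(N²) distances where DIMA uses an ANN-backed sparse graph, and a full eigendecomposition where DIMA uses a sparse eigensolver. The function name `diffusion_map` is hypothetical. `beta`, `eps` (median heuristic), `alpha`, `t`, `drop_trivial` and `"max"` symmetrization mirror the descriptions above.

```python
import numpy as np


def diffusion_map(X, d=2, k=10, beta=1.0, eps=None, alpha=1.0, t=0.5, drop_trivial=True):
    """Hypothetical dense Diffusion Maps sketch (not DIMA's sparse implementation)."""
    N = X.shape[0]
    # Pairwise squared distances (dense O(N^2); only suitable for small toy data).
    sq = np.sum(X**2, axis=1)
    D2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    # kNN graph: each row keeps itself plus its k nearest neighbors.
    cols = np.argsort(D2, axis=1)[:, : k + 1].ravel()
    rows = np.repeat(np.arange(N), k + 1)
    d2_knn = D2[rows, cols]
    if eps is None:
        # Median heuristic over the nonzero kNN squared distances.
        eps = np.median(d2_knn[d2_knn > 0])
    K = np.zeros((N, N))
    K[rows, cols] = np.exp(-beta * d2_knn / eps)
    K = np.maximum(K, K.T)  # "max" symmetrization
    # Density normalization: K_ij / (q_i^alpha * q_j^alpha).
    q = K.sum(axis=1)
    K = K / np.outer(q**alpha, q**alpha)
    # Eigendecompose the symmetric conjugate S = D^{-1/2} K D^{-1/2} of P = D^{-1} K.
    deg = K.sum(axis=1)
    S = K / np.sqrt(np.outer(deg, deg))
    lam, V = np.linalg.eigh(S)
    order = np.argsort(lam)[::-1]
    lam, V = lam[order], V[:, order]
    psi = V / np.sqrt(deg)[:, None]  # right eigenvectors of P
    start = 1 if drop_trivial else 0  # index 0 is the constant mode (eigenvalue 1)
    lam, psi = lam[start : start + d], psi[:, start : start + d]
    # Diffusion coordinates: scale each eigenvector by lambda^t (tiny negatives clipped).
    return psi * np.clip(lam, 0.0, None) ** t, lam
```

Take evenly spaced points on a circle, e.g. `theta = np.linspace(0, 2*np.pi, 200, endpoint=False)`. For those, the two leading non-trivial coordinates should span a cosine/sine pair, so the embedding traces a circle again (up to rotation and scale).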
Typical DMAP presets
- Fast-ish and stable: `k=128..512`, `alpha=1.0`, `t=0.5`, `drop_trivial=True`.
- If your data is very noisy, try a larger `k` and/or a larger `eps_mul`.
2) GPLM decoder (Nyström kernel ridge / inducing GP)
GPLM learns a mapping from latents back to ambient:
- Inputs: latent training points $R_{ix}\in\mathbb{R}^{N\times d}$
- Targets: ambient training points $R_{iX}\in\mathbb{R}^{N\times D}$
Instead of a full $N\times N$ kernel solve, GPLM uses inducing points $Z_{mx}$ with $m\ll N$ and solves a reduced system:
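- Build affinities $C_{im}=\exp\left(-\beta\,\|R_{ix}-Z_{mx}\|^2/\varepsilon\right)$ between training latents and inducing points
- Solve a stabilized kernel ridge / GP-mean system to obtain the weights $M_{mX}$
- Predict for new latents $R_{ax}$: $\hat R_{aX}=C_{am}M_{mX}+\mu_X$, where $C_{am}$ are the affinities between $R_{ax}$ and the inducing points and $\mu_X$ is the ambient training mean (re-added when `center_X=True`)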
Main knobs (GPLM):
- `m` (int): number of inducing points. Larger values give better accuracy at more compute.
- `inducing` (str): inducing-point strategy:
  - `"kmeans_medoids"` (default): k-means centers snapped to the nearest training latent.
  - `"fps"`: farthest-point sampling (space-filling).
  - `"random_subset"`: fastest.
  - `"given"`: use the provided `Z_mx`.
- `sigma2`/`σ2` (float): ridge regularization. Too small can overfit or cause instability; too large blurs reconstructions.
- `jitter` (float): tiny diagonal stabilizer for the Cholesky factorization.
- `eps`/`ε` (float or None): RBF bandwidth in latent space. If `None`, it is estimated from latent kNN distances.
- `k_eps`/`κ_eps` (int): number of neighbors used for the $\varepsilon$ heuristic.
- `pred_k`/`pred_κ` (int or None): number of inducing neighbors used at prediction time:
  - `None` uses all inducing points (best accuracy).
  - A small number (e.g. `128` or `256`) speeds up inference at slightly lower accuracy.
- `whiten_latent` (bool): optionally standardize latent dimensions before computing the kernel.
- `center_X` (bool): subtract and re-add the ambient mean (usually helpful).
- `fit_block` (int): block size for streaming $C^\top C$ accumulation (memory/performance knob).
- `ann_backend` (str): same options as DMAP.
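For the `"fps"` strategy, greedy farthest-point sampling picks each new inducing point as the training latent farthest from every point chosen so far. Below is a minimal NumPy sketch. The function name and the random choice of the first point are assumptions; DIMA's initialization may differ.

```python
import numpy as np


def farthest_point_sampling(R_ix, m, seed=0):
    """Hypothetical greedy FPS: indices of m space-filling rows of R_ix."""
    rng = np.random.default_rng(seed)
    N = R_ix.shape[0]
    idx = np.empty(m, dtype=np.int64)
    idx[0] = rng.integers(N)  # arbitrary first point
    # Squared distance from every point to its nearest already-selected point.
    d2 = np.sum((R_ix - R_ix[idx[0]]) ** 2, axis=1)
    for j in range(1, m):
        idx[j] = np.argmax(d2)  # farthest from the current selection
        d2 = np.minimum(d2, np.sum((R_ix - R_ix[idx[j]]) ** 2, axis=1))
    return idx
```

The inducing set would then be `Z_mx = R_ix[idx]`. Each iteration costs O(N·d), so selecting `m` points costs O(N·m·d).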
Typical GPLM presets
- Accurate: `m=1024..4096`, `pred_k=None`, `sigma2=1e-5` (tune).
- Faster inference: set `pred_k=128..512`.
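The `pred_k` trade-off can be illustrated by restricting the prediction sum $\hat R_{aX}=C_{am}M_{mX}+\mu_X$ to each query's nearest inducing points. In this hypothetical sketch the neighbors are found by brute force for clarity; a real speed-up would come from an ANN index. Because RBF affinities decay quickly, the dropped far-away terms are tiny, which is why accuracy drops only slightly.

```python
import numpy as np


def predict_topk(R_ax, Z_mx, M_mX, mu_X, beta=1.0, eps=1.0, pred_k=None):
    """Hypothetical GPLM-style prediction using only the pred_k nearest inducing points."""
    d2 = (np.sum(R_ax**2, axis=1)[:, None] + np.sum(Z_mx**2, axis=1)[None, :]
          - 2.0 * R_ax @ Z_mx.T)
    d2 = np.maximum(d2, 0.0)
    if pred_k is None or pred_k >= Z_mx.shape[0]:
        return np.exp(-beta * d2 / eps) @ M_mX + mu_X  # use all m inducing points
    # Column indices of the pred_k nearest inducing points for each query row.
    nbr = np.argpartition(d2, pred_k - 1, axis=1)[:, :pred_k]         # (B, pred_k)
    C_ak = np.exp(-beta * np.take_along_axis(d2, nbr, axis=1) / eps)  # (B, pred_k)
    # sum_k C_ak * M_kX over the selected inducing points only.
    return np.einsum("bk,bkx->bx", C_ak, M_mX[nbr]) + mu_X
```

The kernel-weighted sum now touches `pred_k` rows of `M_mX` per query instead of all `m`, matching the $B\cdot m$ versus $B\cdot\text{pred\_k}$ inference costs.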
3) DDPM latent diffusion (optional generative prior)
DDPM learns a distribution over normalized latents
$$Z=\frac{R_{ix}-\mu}{\sigma},$$
where $\mu$ and $\sigma$ are the mean and standard deviation of the training latents, and can:
- sample new latents,
- refine a given latent by projecting it onto the learned latent manifold/prior.
In DIMA, DDPM operates purely in latent space (dimension d), so it’s lightweight compared to image DDPMs.
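The normalization and the forward (noising) process that DDPM is trained on can be sketched in NumPy. This sketch makes three assumptions: per-dimension statistics for $\mu$ and $\sigma$, a linear $\beta$ schedule ending at `beta_max` (DIMA's actual schedule is not described here), and hypothetical function names. A denoising MLP would then be trained to predict `noise` from `(z_t, t)` with a mean-squared-error loss, the standard DDPM objective.

```python
import numpy as np


def normalize_latents(R_ix, eps=1e-8):
    # Z = (R_x - mu) / sigma with per-dimension statistics (an assumption).
    mu = R_ix.mean(axis=0)
    sigma = R_ix.std(axis=0) + eps
    return (R_ix - mu) / sigma, mu, sigma


def linear_schedule(T=200, beta_min=1e-4, beta_max=0.02):
    # Hypothetical linear noise schedule; alpha_bar[t] = prod_{s<=t} (1 - beta_s).
    betas = np.linspace(beta_min, beta_max, T)
    return betas, np.cumprod(1.0 - betas)


def q_sample(z0, t, alpha_bar, rng):
    # Forward process: z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * noise.
    noise = rng.standard_normal(z0.shape)
    ab = alpha_bar[t][:, None]  # one timestep index per row of z0
    return np.sqrt(ab) * z0 + np.sqrt(1.0 - ab) * noise, noise
```

Because the process acts on `d`-dimensional rows rather than images, one training step is a single small batched MLP call.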
Main knobs (DDPM):
- `T` (int): number of diffusion steps. Common: `100..1000`; DIMA default: `200`.
- `hidden_dim` (int): MLP width.
- `t_embed_dim` (int): time-embedding dimension.
- `n_iter` (int): training iterations (more iterations generally improve sampling quality).
- `batch_size` (int): training batch size.
- `learning_rate` (float): Adam learning rate.
- `ema_decay` (float): EMA smoothing for stable sampling (typical: `0.999`).
- `beta_max` (float): caps the noise schedule (stability knob).
- `eps` (float): numerical stabilizer.
- `verbose_every` (int): how often to print progress.
Refinement knobs (during decoding):
- `refine` (bool): enable/disable DDPM refinement.
- `t_start` (int): how strongly to "project" using reverse diffusion:
  - small (`1..10`): gentle projection.
  - larger (`20..100`): stronger projection (can oversmooth or drift if the DDPM is undertrained).
- `add_noise` (bool): whether to forward-noise the latent before the reverse steps.
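One common way to implement this kind of refinement is SDEdit-style. You optionally forward-noise the latent to step `t_start`, then run `t_start` ancestral DDPM reverse steps. The sketch below assumes that scheme and the standard noise-prediction parameterization; DIMA's exact procedure may differ. So that it runs without a trained network, the hypothetical `eps_model` is the exact noise predictor for a toy Gaussian prior $\mathcal N(0,\operatorname{diag}(s^2))$. That prior is almost flat along one axis, which makes the "projection" effect visible.

```python
import numpy as np


def gaussian_eps_model(s2, alpha_bar):
    # Exact E[noise | z_t] for data ~ N(0, diag(s2)); stands in for a trained MLP.
    def eps_model(z, t):
        ab = alpha_bar[t]
        return np.sqrt(1.0 - ab) * z / (ab * s2 + 1.0 - ab)
    return eps_model


def refine(z0, eps_model, betas, t_start=10, add_noise=True, rng=None):
    """Hypothetical SDEdit-style refinement with t_start reverse DDPM steps."""
    rng = np.random.default_rng() if rng is None else rng
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    z = np.array(z0, dtype=float)
    if add_noise:  # jump straight to step t_start via the forward process
        ab = alpha_bar[t_start - 1]
        z = np.sqrt(ab) * z + np.sqrt(1.0 - ab) * rng.standard_normal(z.shape)
    for t in range(t_start - 1, -1, -1):
        eps_hat = eps_model(z, t)
        # Mean of p(z_{t-1} | z_t) in the standard DDPM parameterization.
        mean = (z - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_hat) / np.sqrt(alphas[t])
        if t > 0:
            var = betas[t] * (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t])
            z = mean + np.sqrt(var) * rng.standard_normal(z.shape)
        else:
            z = mean  # no noise on the final step
    return z
```

For example, take `betas = np.linspace(1e-4, 0.02, 200)`, `s2 = np.array([1.0, 1e-4])` and `t_start=20`. Refining copies of the off-manifold latent `[1.0, 0.5]` should shrink the second coordinate to the order of 0.01 while keeping the first near 1 on average.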
API
Core class
```python
from dima import DIMA
dima = DIMA(R_iX)
```
Encode / decode
```python
Z = dima.encode(X)                               # (B, d) normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)             # (B, D) reconstruction (np)
X_ref = dima.decode(Z, refine=True, t_start=10)  # refined decode
```
Polymorphic call
```python
Z = dima(X)  # encodes if X.shape[-1] == D
X = dima(Z)  # decodes if Z.shape[-1] == d
```
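The dispatch rule above (route by the size of the last axis) can be mimicked with a small wrapper. This hypothetical sketch is not DIMA's code. It also makes explicit that the rule is ambiguous when `d == D`; the sketch simply rejects that case, and how DIMA handles it is not covered here.

```python
import numpy as np


class ShapeDispatch:
    """Hypothetical wrapper: call encode or decode depending on the last-axis size."""

    def __init__(self, D, d, encode, decode):
        if D == d:
            raise ValueError("ambiguous dispatch: ambient and latent dimensions are equal")
        self.D, self.d = D, d
        self.encode, self.decode = encode, decode

    def __call__(self, A):
        last = np.shape(A)[-1]
        if last == self.D:
            return self.encode(A)  # ambient (..., D) -> latent
        if last == self.d:
            return self.decode(A)  # latent (..., d) -> ambient
        raise ValueError(f"last axis {last} matches neither D={self.D} nor d={self.d}")
```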
Sampling
```python
X_gen = dima.sample(1000)                # ambient samples (np)
Z_gen = dima.sample(1000, decode=False)  # latent samples (jnp)
```
Configuration patterns
1) Use defaults (cleanest)
```python
dima = DIMA(R_iX)
```
2) Pass only a few knobs
```python
dima = DIMA(
    R_iX,
    d=64,
    dmap_kwargs=dict(k=256, alpha=1.0, t=0.5),
    gplm_kwargs=dict(m=2048, sigma2=1e-5, pred_k=256),
    ddpm_kwargs=dict(n_iter=50_000, hidden_dim=256),
)
```
Saving / Loading
Local
```python
dima.save_local("dima.msgpack", "config.json")
dima2 = DIMA.load_local("dima.msgpack", ddpm_device="auto")
```
Hugging Face (optional)
```python
dima.upload_to_huggingface(repo_id="username/dima-model", hf_token="...")
dima3 = DIMA.load_from_huggingface("username/dima-model", ddpm_device="auto")
```
Performance notes
- DMAP: the main costs are the kNN search and the sparse eigensolve.
  - kNN search is faster with FAISS.
  - Increasing `k` increases graph density and compute.
- GPLM: training is linear in $N$. Accumulating $C^\top C$ (streamed in blocks of `fit_block` rows) costs roughly $O(N\cdot m^2)$, plus an $O(m^3)$ factorization.
  - Inference cost scales with $B\cdot m$ if `pred_k=None`, or with $B\cdot \text{pred\_k}$ if using inducing kNN.
- DDPM: cost scales with the latent dimension `d`, the number of steps `T`, and the number of training iterations.
  - A GPU helps, but CPU works for smaller `d` and fewer iterations.
Troubleshooting
“Unknown backend: gpu”
Your JAX installation is CPU-only. Either run DDPM on the CPU:
```python
dima = DIMA(R_iX, ddpm_device="cpu")
```
or install a GPU-enabled JAX build.
Reconstructions are blurry / low quality
- Increase `gplm_kwargs["m"]`.
- Decrease `gplm_kwargs["sigma2"]` slightly (careful: too small can destabilize).
- Set `pred_k=None` in `gplm_kwargs` (use all inducing points).
DDPM refinement makes reconstructions worse
- Reduce `t_start` (try `3..10`).
- Increase DDPM training (`ddpm_kwargs["n_iter"]`).
- Disable `add_noise` for gentler behavior.
Citations
- Diffusion Maps / DMAE inspiration: Diffusion Map AutoEncoder (DMAE).
- Latent diffusion prior: Denoising Diffusion Probabilistic Models (DDPM).
License
MIT
File details
Details for the file dima-0.1.1.tar.gz.
File metadata
- Download URL: dima-0.1.1.tar.gz
- Upload date:
- Size: 28.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 939f207cdb4e2d846758e7f04969ff3bea1b2046c42332e5d699188fa7ff7f2b |
| MD5 | 602ae3b63139ee2421b44b3246fc7cb0 |
| BLAKE2b-256 | aed3eb4e180493dcaeb8696f9c7905f72504e347e0132c3a2a9c32e1c90031e0 |
File details
Details for the file dima-0.1.1-py3-none-any.whl.
File metadata
- Download URL: dima-0.1.1-py3-none-any.whl
- Upload date:
- Size: 28.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ee81ee9970b2532aab6d6334d0dc863282c6b2bbf08e379cafab74482df1c005 |
| MD5 | 7783746d0113901fd467f4acc0216815 |
| BLAKE2b-256 | f58a372d3c2e53c801a0c34896398a6924feddb7f7c73056a719d76ef3b7fd09 |