
Windowed Diffusion Transformer in 2D and 3D


WiDiT is a SwinIR-style DiT backbone that unifies 2D images and 3D volumes with N-D windowed attention, optional Swin shifts, and AdaLN-Zero conditioning.
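In the Swin Transformer formulation, a "shifted" block cyclically rolls the token grid (typically by half a window) before window partitioning and rolls it back afterwards, so information can cross window borders. The pure-Python sketch below shows that N-D roll on a dict-of-coordinates grid. The function name `cyclic_shift` and the half-window offset are illustrative assumptions, not widit's actual code, which presumably operates on tensors.

```python
from itertools import product


def cyclic_shift(grid, shape, shifts):
    """Roll an N-D grid (dict: coordinate tuple -> value) along every axis.

    Mirrors torch.roll semantics: the element at index i on an axis moves to
    (i + shift) % size, so a negative shift moves content towards index 0.
    Hypothetical helper for illustration only.
    """
    out = {}
    for coord in product(*(range(n) for n in shape)):
        new = tuple((c + s) % n for c, s, n in zip(coord, shifts, shape))
        out[new] = grid[coord]
    return out


# 2D token grid of shape (8, 8) whose values are their own coordinates.
shape = (8, 8)
grid = {c: c for c in product(range(8), range(8))}

window = 4
shift = window // 2                                      # assumed half-window shift
shifted = cyclic_shift(grid, shape, (-shift, -shift))    # before windowed attention
restored = cyclic_shift(shifted, shape, (shift, shift))  # undo after attention
```

Because the function loops over an arbitrary number of axes, the same code shifts a 3D `(D, H, W)` grid without changes, which is the point of sharing one block across 2D and 3D.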

  • Single model class: widit.models.WiDiT

  • Optional timestep conditioning (pass timestep=None if unused)

  • Shared blocks for 2D/3D via N-D window partitioning

  • Presets for quick experiments in both 2D and 3D
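N-D window partitioning can be written once and reused for any number of spatial axes. The sketch below groups token coordinates into non-overlapping windows, conceptually padding each axis up to a multiple of its window size (padded slots are `None`). It is a pure-Python illustration under the assumption that widit tiles the token grid this way; `window_partition` is a hypothetical name, not part of the widit API.

```python
from itertools import product


def window_partition(grid_shape, window):
    """Group the coordinates of an N-D token grid into non-overlapping windows.

    Each axis is conceptually padded up to a multiple of its window size;
    positions that fall in the padding are returned as None.
    """
    pad = [(-n) % w for n, w in zip(grid_shape, window)]
    counts = [(n + p) // w for n, p, w in zip(grid_shape, pad, window)]
    windows = []
    for widx in product(*(range(c) for c in counts)):
        tokens = []
        for offset in product(*(range(w) for w in window)):
            coord = tuple(i * w + o for i, w, o in zip(widx, window, offset))
            inside = all(c < n for c, n in zip(coord, grid_shape))
            tokens.append(coord if inside else None)
        windows.append(tokens)
    return windows


# 2D: a 128x96 image with patch_size=2 gives a 64x48 token grid; 8x8 windows tile it exactly.
w2d = window_partition((64, 48), (8, 8))
# 3D: a 5x4x4 token grid with 4x4x4 windows needs padding along the first axis.
w3d = window_partition((5, 4, 4), (4, 4, 4))
```

Only the window-count and offset loops depend on the number of axes, and both are driven by `product`, so 2D and 3D inputs share every line.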

Installation

Install using pip:

pip install widit

Or, to install the latest development version from GitHub:

pip install git+https://github.com/rbturnbull/widit.git

WiDiT depends on torch and timm (for the 2D patch embedding path).

Quick Start (2D)

import torch
from widit.models import WiDiT

# Example: 2D RGB input & conditioning (e.g., low-res guidance)
N, C, H, W = 2, 3, 128, 96
x      = torch.randn(N, C, H, W)
cond   = torch.randn_like(x)
t      = torch.randint(0, 1000, (N,), dtype=torch.long)  # optional

model = WiDiT(
    spatial_dim=2,
    input_size=(H, W),        # kept for API parity; not required at forward
    patch_size=2,             # must divide H and W
    in_channels=C,
    hidden_size=256,          # must be divisible by num_heads and even
    depth=6,
    num_heads=8,
    window_size=8,            # can be int or (wh, ww)
    mlp_ratio=4.0,
    learn_sigma=True,         # output channels = 2*C if True
)

# Timestep is optional; pass None to disable conditioning
y = model(x, cond, t)         # (N, 2*C, H, W) if learn_sigma=True

Quick Start (3D)

import torch
from widit.models import WiDiT

# Example: 3D single-channel volumes
N, C, D, H, W = 1, 1, 64, 64, 48
x    = torch.randn(N, C, D, H, W)
cond = torch.randn_like(x)

model = WiDiT(
    spatial_dim=3,
    input_size=(D, H, W),
    patch_size=2,             # must divide D/H/W
    in_channels=C,
    hidden_size=256,
    depth=4,
    num_heads=8,
    window_size=(4, 4, 4),    # can be int or (wd, wh, ww)
    mlp_ratio=4.0,
    learn_sigma=False,        # output channels = C if False
)

y = model(x, cond, timestep=None)  # (N, C, D, H, W)

Presets

Presets provide ready-made configurations for common model sizes (2D & 3D), all using patch_size=2 and Swin-style window attention:

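import torch
from widit.models import PRESETS

# 2D: B, M, L, XL
model_2d = PRESETS["WiDiT-L/2"](in_channels=3, learn_sigma=True)

# 3D: B, M, L, XL
model_3d = PRESETS["WiDiT3D-M/2"](in_channels=1, learn_sigma=False)

# Example inputs: every spatial dimension must be divisible by patch_size=2
x2d = torch.randn(2, 3, 128, 96)
cond2d = torch.randn_like(x2d)
x3d = torch.randn(1, 1, 64, 64, 48)
cond3d = torch.randn_like(x3d)

# Run
y2d = model_2d(x2d, cond2d, timestep=None)   # (2, 6, 128, 96)
y3d = model_3d(x3d, cond3d, timestep=torch.randint(0, 1000, (x3d.shape[0],)))  # (1, 1, 64, 64, 48)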

API Overview

WiDiT(
    *,
    spatial_dim: int,                          # 2 (images) or 3 (volumes)
    input_size: int | Sequence[int] | None = None,
    patch_size: int | Sequence[int] = 2,       # per-axis tuple allowed
    in_channels: int = 1,
    hidden_size: int = 768,                    # even; divisible by num_heads
    depth: int = 12,
    num_heads: int = 12,
    window_size: int | Sequence[int] = 8,      # per-axis tuple allowed
    mlp_ratio: float = 4.0,
    learn_sigma: bool = True,
)

forward(
    input_tensor:       torch.Tensor,          # (N, C, *spatial)
    conditioned_tensor: torch.Tensor,          # (N, C, *spatial), same shape as input_tensor
    timestep:           torch.Tensor | None = None,  # (N,) or None
) -> torch.Tensor                              # (N, out_channels, *spatial)

Shapes & contracts

  • *spatial is (H, W) for 2D and (D, H, W) for 3D.

  • patch_size must evenly divide each spatial dimension.

  • window_size can be an int or a per-axis tuple; internal padding ensures full windows (removed before returning).

  • hidden_size must be even (split across the two patch embedders) and divisible by num_heads.

  • If learn_sigma=True, output channels = 2 * in_channels (mean + sigma style).
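The contracts above can be checked before building a model. The following pure-Python sketch validates them for a given spatial shape and reports the token grid, the padding that window attention would need on that grid, and the per-head dimension. `check_contracts` is a hypothetical helper written for this README, not a widit function; widit performs its own checks internally.

```python
def check_contracts(spatial, patch_size, window_size, hidden_size, num_heads):
    """Validate the documented WiDiT shape contracts (hypothetical helper)."""
    nd = len(spatial)
    patch = (patch_size,) * nd if isinstance(patch_size, int) else tuple(patch_size)
    window = (window_size,) * nd if isinstance(window_size, int) else tuple(window_size)
    if len(patch) != nd or len(window) != nd:
        raise ValueError("patch_size/window_size need one entry per spatial axis")
    for s, p in zip(spatial, patch):
        if s % p:
            raise ValueError(f"patch_size {p} does not divide spatial size {s}")
    if hidden_size % 2:
        raise ValueError("hidden_size must be even (split across two patch embedders)")
    if hidden_size % num_heads:
        raise ValueError("hidden_size must be divisible by num_heads")
    grid = tuple(s // p for s, p in zip(spatial, patch))
    # Window attention pads the token grid up to full windows, then crops back.
    pad = tuple((-g) % w for g, w in zip(grid, window))
    return {"token_grid": grid, "window_padding": pad, "head_dim": hidden_size // num_heads}


print(check_contracts((128, 96), 2, 8, 256, 8))
# {'token_grid': (64, 48), 'window_padding': (0, 0), 'head_dim': 32}
```

A width of 100 still passes (100 / 2 = 50 tokens), but the 50-token axis needs 6 tokens of window padding; an odd spatial size such as 127 raises, because patch embedding cannot pad.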

Conditioning

  • timestep is optional. Pass None to disable AdaLN conditioning (the block falls back to standard LN + residual).

  • If provided, the model uses widit.timesteps.TimestepEmbedder to produce a per-sample vector projected to the token dimension.
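In the original DiT code, a timestep embedder first maps each integer timestep to a fixed sinusoidal vector and then passes it through a small MLP. The sketch below shows only the sinusoidal step in pure Python, assuming widit follows the DiT convention (geometric frequency spacing, `max_period=10000`, cosine half followed by sine half). `timestep_embedding` is a hypothetical name here, and the learned MLP that projects to `hidden_size` is omitted.

```python
import math


def timestep_embedding(t, dim, max_period=10000):
    """DiT-style sinusoidal embedding of a scalar timestep (assumed convention)."""
    half = dim // 2
    # Frequencies decay geometrically from 1 down to roughly 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # zero-pad odd dimensions
    return emb


vec = timestep_embedding(500, 16)  # 16-dim vector for timestep 500
```

Each sample in a batch gets its own vector, which is why `timestep` has shape `(N,)`.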

Building Blocks

These are used internally, but you can also import them for custom stacks.

  • widit.blocks.WiDiTBlock – N-D windowed MSA + MLP with AdaLN-Zero

  • widit.blocks.WiDiTFinalLayer – final projection head with AdaLN-Zero

  • widit.patch.PatchEmbed – unified 2D/3D patch embedding

  • widit.timesteps.TimestepEmbedder – sinusoidal → MLP conditioning

All of the above expose init_weights() so the model can initialize components cleanly (adaLN-Zero policy for blocks & head; Xavier for projections; Normal for timestep MLP weights).
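The "Zero" in adaLN-Zero refers to initializing the projection that produces shift, scale, and gate to zero, so every residual branch starts as the identity. The pure-Python sketch below illustrates one such branch on plain lists; LayerNorm is omitted for brevity, and the helper names are hypothetical. In DiT each block produces six vectors (shift, scale, and gate for both attention and MLP); only one branch is shown here.

```python
def modulate(x, shift, scale):
    """AdaLN modulation: x * (1 + scale) + shift, elementwise."""
    return [xi * (1.0 + sc) + sh for xi, sh, sc in zip(x, shift, scale)]


def linear(vec, weight, bias):
    """Dense layer y = W @ vec + b on plain lists."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def adaln_zero_branch(x, c, weight, bias, branch):
    """x + gate * branch(modulate(x, shift, scale)), conditioned on vector c."""
    d = len(x)
    params = linear(c, weight, bias)  # 3 * d values: shift, scale, gate
    shift, scale, gate = params[:d], params[d:2 * d], params[2 * d:]
    h = branch(modulate(x, shift, scale))
    return [xi + g * hi for xi, g, hi in zip(x, gate, h)]


x = [1.0, 2.0, 3.0]
c = [0.5, -1.0]                                  # e.g. a timestep embedding
mlp = lambda v: [2 * vi + 1 for vi in v]         # stand-in for attention or MLP
zero_w = [[0.0, 0.0] for _ in range(3 * len(x))]
zero_b = [0.0] * (3 * len(x))
out = adaln_zero_branch(x, c, zero_w, zero_b, mlp)  # identical to x at init
```

Since the gate starts at zero, the network initially behaves like a stack of identities and learns how much each branch should contribute.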

Training Snippet

import torch
from torch.optim import AdamW
from widit.models import WiDiT

device = "cuda" if torch.cuda.is_available() else "cpu"

model = WiDiT(
    spatial_dim=2,
    in_channels=3,
    hidden_size=256,
    depth=6,
    num_heads=8,
    patch_size=2,
    window_size=8,
    learn_sigma=True,
).to(device)

opt = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

for step in range(100):
    x    = torch.randn(8, 3, 128, 96, device=device)
    cond = torch.randn_like(x)
    t    = torch.randint(0, 1000, (x.shape[0],), device=device)

    y = model(x, cond, t)                      # (N, 6, H, W) here (mean+sigma for C=3)
    target = torch.randn_like(y)               # placeholder target for illustration only

    loss = torch.nn.functional.mse_loss(y, target)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

Tips & Gotchas

  • Patch size equality in unpatchify: currently the unpatchify path enforces equal patch size along all axes (e.g., patch_size=2 or (2,2,2)). Mixed per-axis patch sizes for output reconstruction are not supported yet.

  • Token grid divisibility: every spatial dimension must be divisible by patch_size. Window attention pads internally to complete windows and crops back afterwards, but patch embedding is stride-based and does not pad.

  • Timestep optional: pass timestep=None to run the model without diffusion conditioning (AdaLN defaults reduce to a vanilla transformer residual path).

  • Mixed precision: standard AMP (torch.cuda.amp, or torch.amp in recent PyTorch releases) works out of the box.

Reference Shapes

2D

  • Input: (N, C, H, W)

  • Output: (N, 2*C, H, W) if learn_sigma=True, else (N, C, H, W)

3D

  • Input: (N, C, D, H, W)

  • Output: (N, 2*C, D, H, W) if learn_sigma=True, else (N, C, D, H, W)
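The reference shapes reduce to one rule: spatial dimensions are preserved and the channel count doubles when learn_sigma=True. A small pure-Python sketch of that rule (`widit_output_shape` is a hypothetical helper, useful for pre-allocating buffers or asserting shapes in tests):

```python
def widit_output_shape(input_shape, learn_sigma=True):
    """Expected WiDiT output shape for an (N, C, *spatial) input (2 or 3 spatial axes)."""
    n, c, *spatial = input_shape
    if len(spatial) not in (2, 3):
        raise ValueError("expected (N, C, H, W) or (N, C, D, H, W)")
    out_channels = 2 * c if learn_sigma else c
    return (n, out_channels, *spatial)


print(widit_output_shape((2, 3, 128, 96)))                       # (2, 6, 128, 96)
print(widit_output_shape((1, 1, 64, 64, 48), learn_sigma=False))  # (1, 1, 64, 64, 48)
```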

Credits

Robert Turnbull - Melbourne Data Analytics Platform (MDAP), The University of Melbourne



Download files


Source Distribution

widit-0.1.0a1.tar.gz (18.9 kB)

Uploaded Source

Built Distribution


widit-0.1.0a1-py3-none-any.whl (19.3 kB)

Uploaded Python 3

File details

Details for the file widit-0.1.0a1.tar.gz.

File metadata

  • Download URL: widit-0.1.0a1.tar.gz
  • Upload date:
  • Size: 18.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.1 CPython/3.13.1 Darwin/24.6.0

File hashes

Hashes for widit-0.1.0a1.tar.gz
Algorithm Hash digest
SHA256 56968319e3bc9c9daf4321f9ed62942e05ab6c8825565e42486920daca67e49f
MD5 aa146beffdc693302d8ec5ca083f4a73
BLAKE2b-256 40ec868bea1dd1c030b93c02b547c808307c857e50d4b3abaf2f418b531cb5fb


File details

Details for the file widit-0.1.0a1-py3-none-any.whl.

File metadata

  • Download URL: widit-0.1.0a1-py3-none-any.whl
  • Upload date:
  • Size: 19.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.1 CPython/3.13.1 Darwin/24.6.0

File hashes

Hashes for widit-0.1.0a1-py3-none-any.whl
Algorithm Hash digest
SHA256 c48a3eb507d6641fc5345caed9ca4be2b39bfab6bc3e813821f38dcc68ebfd3b
MD5 90f0235123fbd98703583711abf40839
BLAKE2b-256 bbb4638eeefa9985d29ece50e2fca561ff7778629be836e5641b0512a0698831

