
Differentiable CT reconstruction for Apple Silicon using MLX and custom Metal kernels


diffct_mlx: Differentiable CT for Apple Silicon


A high-performance, differentiable computed tomography (CT) reconstruction library built with MLX and custom Metal kernels, optimized for Apple Silicon (M-series) chips.

This is the Apple Silicon port of diffct, replacing CUDA/PyTorch with MLX/Metal for native M-series GPU acceleration.

Features

  • Apple Silicon Native: Custom Metal kernels via mx.fast.metal_kernel — no CUDA required
  • Differentiable: End-to-end gradient propagation using mx.custom_function with custom VJPs
  • Siddon Ray-Tracing: Bilinear (2D) and trilinear (3D) interpolation for accurate projection
  • Atomic Backprojection: Thread-safe gradient accumulation using Metal atomic operations
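The differentiable and atomic-backprojection features depend on one property: the backward kernel must be the exact adjoint (transpose) of the forward kernel, because the VJP of a linear projector A applied to a cotangent g is Aᵀg. The NumPy sketch below illustrates this with a toy 2D parallel-beam projector. It is a stand-in, not the library's Metal code: it uses fixed-step ray marching with bilinear interpolation rather than Siddon's exact voxel traversal, and every function name is hypothetical. The backprojector scatters exactly the weights the forward pass gathers, which is the serial counterpart of the atomic adds a GPU kernel needs when many rays hit the same pixel.

```python
import numpy as np

def bilinear_taps(x, y, H, W):
    """(row, col, weight) of the up-to-four pixels around (x, y); outside pixels count as zero."""
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    fx, fy = x - x0, y - y0
    for dr, dc, w in ((0, 0, (1 - fx) * (1 - fy)), (0, 1, fx * (1 - fy)),
                      (1, 0, (1 - fx) * fy), (1, 1, fx * fy)):
        r, c = y0 + dr, x0 + dc
        if 0 <= r < H and 0 <= c < W:
            yield r, c, w

def ray_points(theta, u, H, W, step):
    """Equally spaced points (pixel coordinates) along the parallel ray at angle theta, offset u."""
    cx, cy = (W - 1) / 2, (H - 1) / 2
    d = np.array([np.cos(theta), np.sin(theta)])    # ray direction
    n = np.array([-np.sin(theta), np.cos(theta)])   # detector axis
    half = 0.5 * np.hypot(H, W)
    for t in np.arange(-half, half, step):
        p = u * n + t * d
        yield p[0] + cx, p[1] + cy

def forward(img, thetas, us, step=0.5):
    """Gather: each sinogram bin is a weighted sum of interpolated pixel values."""
    H, W = img.shape
    sino = np.zeros((len(thetas), len(us)))
    for i, th in enumerate(thetas):
        for j, u in enumerate(us):
            for x, y in ray_points(th, u, H, W, step):
                for r, c, w in bilinear_taps(x, y, H, W):
                    sino[i, j] += step * w * img[r, c]
    return sino

def backward(sino, thetas, us, shape, step=0.5):
    """Scatter: the same weights, transposed. On a GPU this += must be an atomic add."""
    H, W = shape
    img = np.zeros(shape)
    for i, th in enumerate(thetas):
        for j, u in enumerate(us):
            for x, y in ray_points(th, u, H, W, step):
                for r, c, w in bilinear_taps(x, y, H, W):
                    img[r, c] += step * w * sino[i, j]
    return img

def adjoint_gap(shape=(8, 8), n_views=6, n_det=12, seed=0):
    """Relative mismatch between <A x, y> and <x, A^T y> for random x, y (dot-product test)."""
    rng = np.random.default_rng(seed)
    thetas = np.linspace(0.0, np.pi, n_views, endpoint=False)
    us = np.arange(n_det) - (n_det - 1) / 2
    x = rng.random(shape)
    y = rng.random((n_views, n_det))
    lhs = np.vdot(forward(x, thetas, us), y)
    rhs = np.vdot(x, backward(y, thetas, us, shape))
    return abs(lhs - rhs) / abs(lhs)
```

A gap at rounding level means the pair is a valid forward/VJP couple; the same dot-product test is a cheap way to validate any custom projector pair.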

Supported Geometries

Geometry          Forward  Backward  Differentiable
2D Parallel Beam  Yes      Yes       Yes
2D Fan Beam       Yes      Yes       Yes
3D Cone Beam      Yes      Yes       Yes

Trajectory Generators

  • Circular — standard single-rotation scan
  • Spiral / Helical — helical CT with z-axis translation (3D)
  • Sinusoidal — variable source-to-isocenter distance
  • Saddle — combined z-oscillation and radial variation (3D)
  • Random — perturbed circular with configurable noise (3D)
  • Custom — user-defined source path functions
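To make the trajectory families concrete, here is a NumPy sketch of a generic cone-beam path builder. It is hypothetical (not the diffct_mlx API) and assumes these conventions: the source orbits the z axis at distance sid, the flat detector center sits sdd away on the opposite side, det_u is the in-plane detector axis and det_v is the z axis. Each listed family then becomes a choice of two hypothetical per-view functions, one for the source-to-isocenter distance and one for the axial offset.

```python
import numpy as np

def source_path_3d(n_views, sid, sdd, turns=1.0, sid_fn=None, z_fn=None):
    """Hypothetical cone-beam trajectory builder (not the diffct_mlx API).

    sid_fn(theta): source-to-isocenter distance per view (default: constant sid)
    z_fn(theta):   axial offset of source and detector per view (default: 0)
    """
    theta = np.linspace(0.0, 2 * np.pi * turns, n_views, endpoint=False)
    r = sid_fn(theta) if sid_fn is not None else np.full_like(theta, sid)
    z = z_fn(theta) if z_fn is not None else np.zeros_like(theta)
    zeros = np.zeros_like(theta)
    radial = np.stack([np.cos(theta), np.sin(theta), zeros], axis=1)  # unit vector isocenter -> source
    src = radial * r[:, None]
    src[:, 2] = z
    det_center = src - sdd * radial                                    # detector faces the source at distance sdd
    det_u = np.stack([-np.sin(theta), np.cos(theta), zeros], axis=1)   # in-plane detector axis
    det_v = np.tile([0.0, 0.0, 1.0], (n_views, 1))                     # axial detector axis
    return src, det_center, det_u, det_v

# The listed families as parameter choices (amplitudes are arbitrary examples)
circular = source_path_3d(60, 500.0, 1000.0)
spiral = source_path_3d(120, 500.0, 1000.0, turns=2.0, z_fn=lambda t: 10.0 * t / (2 * np.pi))
sinusoidal = source_path_3d(60, 500.0, 1000.0, sid_fn=lambda t: 500.0 + 50.0 * np.sin(2 * t))
saddle = source_path_3d(60, 500.0, 1000.0, z_fn=lambda t: 15.0 * np.cos(2 * t),
                        sid_fn=lambda t: 500.0 + 20.0 * np.cos(2 * t))
```

Because sdd is held fixed, the detector follows the source when sid varies; a real system might instead keep the detector-to-isocenter distance fixed.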

Quick Start

Prerequisites

  • Apple Silicon Mac (M1/M2/M3/M4 series)
  • Python 3.10+
  • macOS 13.5+

Installation from PyPI

pip install diffct_mlx

Installation from Source

# Clone the repository
git clone https://github.com/Linda-SophieSchneider/diffct_arbit.git
cd diffct_arbit

# Create and activate conda environment
conda create -n diffct-mlx python=3.11
conda activate diffct-mlx

# Install MLX and dependencies
pip install mlx numpy matplotlib

# Install diffct_mlx
pip install -e .

Build a Release Distribution

python -m pip install --upgrade build twine
python -m build
python -m twine check dist/*

This creates both a source distribution and a wheel in dist/.

Publish to TestPyPI or PyPI

# TestPyPI
python -m twine upload --repository testpypi dist/*

# PyPI
python -m twine upload dist/*

For publishing from your machine, create an API token on PyPI and use it as the password with the username __token__.

If you publish through GitHub Actions, prefer PyPI Trusted Publishing instead of storing a long-lived API token.

The repository release workflow is documented in RELEASING.md.

Basic Usage

import mlx.core as mx
import diffct_mlx

# Create a 64x64 test image
image = mx.ones((64, 64), dtype=mx.float32)

# Generate parallel beam geometry (90 views)
ray_dir, det_origin, det_u_vec = diffct_mlx.circular_trajectory_2d_parallel(90)

# Forward projection → sinogram
sino = diffct_mlx.parallel_forward(
    image, ray_dir, det_origin, det_u_vec,
    num_detectors=92, detector_spacing=1.0, voxel_spacing=1.0
)

# Backprojection → reconstruction
reco = diffct_mlx.parallel_backward(
    sino, ray_dir, det_origin, det_u_vec,
    H=64, W=64, detector_spacing=1.0, voxel_spacing=1.0
)
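The call above pairs a 64×64 image with num_detectors=92. One quick sanity check, assuming the detector is centered on the rotation axis and detector_spacing and voxel_spacing share units, is that the detector must span the image diagonal (64·√2 ≈ 90.5 pixels) so no view truncates the object; rounding 91 up to an even count gives 92. The helper below is hypothetical, not part of diffct_mlx.

```python
import math

def min_parallel_detectors(H, W, voxel_spacing=1.0, detector_spacing=1.0, even=True):
    """Smallest detector count whose width covers the image diagonal (centered detector assumed)."""
    diagonal = math.hypot(H, W) * voxel_spacing   # longest chord through the image
    n = math.ceil(diagonal / detector_spacing)
    if even and n % 2:
        n += 1                                    # optional: keep the count even
    return n
```

For example, `min_parallel_detectors(64, 64)` returns 92, matching the value used above.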

Gradient Computation

Since all projectors are differentiable, you can compute gradients directly:

import mlx.core as mx
import diffct_mlx

def loss_fn(image):
    ray_dir, det_origin, det_u_vec = diffct_mlx.circular_trajectory_2d_parallel(90)
    sino = diffct_mlx.parallel_forward(image, ray_dir, det_origin, det_u_vec, 92)
    return mx.sum(sino ** 2)

image = mx.ones((64, 64), dtype=mx.float32)
grad_fn = mx.grad(loss_fn)
gradient = grad_fn(image)
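Under the hood, mx.grad differentiates through the forward projector by calling its custom VJP, which for a linear projector A is the backprojection Aᵀ. For the loss above the cotangent of sino is 2·sino, so the gradient is Aᵀ(2Ax). The NumPy sketch below checks that identity against central finite differences, with a small random matrix standing in for the projector; it illustrates the math, not the MLX code path.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((30, 16))   # stand-in projector: 30 detector bins x 16 pixels (a 4x4 image)
x = rng.random(16)

def loss(x):
    return float(np.sum((A @ x) ** 2))   # same form as sum(sino ** 2)

def grad_via_backprojection(x):
    sino = A @ x
    return A.T @ (2.0 * sino)            # VJP of the projector applied to the cotangent 2*sino

def grad_finite_difference(x, h=1e-3):
    g = np.zeros_like(x)
    for k in range(x.size):
        e = np.zeros_like(x)
        e[k] = h
        g[k] = (loss(x + e) - loss(x - e)) / (2 * h)   # exact for a quadratic, up to rounding
    return g
```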

3D Cone Beam Example

import mlx.core as mx
import diffct_mlx

# Create a 32x64x64 volume (D, H, W)
volume = mx.ones((32, 64, 64), dtype=mx.float32)

# Generate cone beam geometry
src_pos, det_center, det_u_vec, det_v_vec = diffct_mlx.circular_trajectory_3d(
    n_views=60, sid=500.0, sdd=1000.0
)

# Forward projection
# (per-view axis arrays are named *_vec so they are not confused with the det_u/det_v keyword arguments)
sino = diffct_mlx.cone_forward(
    volume, src_pos, det_center, det_u_vec, det_v_vec,
    det_u=64, det_v=32, du=1.0, dv=1.0, voxel_spacing=1.0
)

# Backprojection
reco = diffct_mlx.cone_backward(
    sino, src_pos, det_center, det_u_vec, det_v_vec,
    D=32, H=64, W=64, du=1.0, dv=1.0, voxel_spacing=1.0
)
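The cone-beam call sizes the detector with det_u, det_v, du and dv. With sid=500 and sdd=1000 the magnification at isocenter is sdd/sid = 2, so the object projects larger than it is. The sketch below estimates the flat-detector half-extent needed to see a cylinder of given radius and half-height without truncation. It assumes du and dv are measured in the detector plane (if the library defines them at isocenter, divide the extents by the magnification first); both helpers are hypothetical.

```python
import math

def flat_detector_half_extent(sid, sdd, radius, half_height):
    """Half-width (u) and half-height (v) a flat detector needs to see a whole cylinder.

    radius: cylinder radius around the rotation axis (must be < sid)
    half_height: half of the cylinder's axial length
    """
    if radius >= sid:
        raise ValueError("source must lie outside the object cylinder")
    u_half = sdd * radius / math.sqrt(sid**2 - radius**2)   # rays tangent to the cylinder
    v_half = sdd * half_height / (sid - radius)              # top rim point closest to the source
    return u_half, v_half

def detector_counts(sid, sdd, radius, half_height, du, dv):
    """Even detector counts covering the extents above with pixel sizes du, dv."""
    u_half, v_half = flat_detector_half_extent(sid, sdd, radius, half_height)
    return 2 * math.ceil(u_half / du), 2 * math.ceil(v_half / dv)

# 32x64x64 volume with unit voxels: in-plane radius is half the slice diagonal
n_u, n_v = detector_counts(500.0, 1000.0, math.hypot(32, 32), 16.0, du=1.0, dv=1.0)
```

Under the detector-plane assumption, passing smaller counts means the outer part of the volume falls outside some views.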

API Reference

Projectors

Function                                                                 Description
parallel_forward(image, ray_dir, det_origin, det_u_vec, ...)             2D parallel beam forward projection
parallel_backward(sinogram, ray_dir, det_origin, det_u_vec, ...)         2D parallel beam backprojection
fan_forward(image, src_pos, det_center, det_u_vec, ...)                  2D fan beam forward projection
fan_backward(sinogram, src_pos, det_center, det_u_vec, ...)              2D fan beam backprojection
cone_forward(volume, src_pos, det_center, det_u_vec, det_v_vec, ...)     3D cone beam forward projection
cone_backward(sinogram, src_pos, det_center, det_u_vec, det_v_vec, ...)  3D cone beam backprojection

Iterative Reconstruction Algorithms

The package also exposes callback-based iterative reconstruction algorithms that are independent of the projection geometry and dimensionality:

Function            Description
run_sart(...)       SART with user-provided single-view forward/backprojectors
run_tv_pocs(...)    TV-POCS using the same projector callback pattern
run_asd_pocs(...)   ASD-POCS with adaptive TV step-size damping
run_awtv_pocs(...)  AwTV-POCS with edge-adaptive weighted TV regularization
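The TV-POCS family alternates a data-consistency pass with a few steepest-descent steps on total variation. A minimal NumPy sketch of that regularization half follows, assuming smoothed isotropic TV with forward differences and a fixed normalized step; the library's regularization dataclasses may use different smoothing, step rules, or (for AwTV) edge weights, and all names here are hypothetical.

```python
import numpy as np

def tv(img, eps=1e-8):
    """Smoothed isotropic total variation; forward differences, zero difference at the far edges."""
    dx = np.diff(img, axis=1, append=img[:, -1:])
    dy = np.diff(img, axis=0, append=img[-1:, :])
    return float(np.sum(np.sqrt(dx**2 + dy**2 + eps)))

def tv_gradient(img, eps=1e-8):
    """Exact gradient of tv(): minus the backward-difference divergence of the normalized gradient."""
    dx = np.diff(img, axis=1, append=img[:, -1:])
    dy = np.diff(img, axis=0, append=img[-1:, :])
    mag = np.sqrt(dx**2 + dy**2 + eps)
    px, py = dx / mag, dy / mag
    return -(np.diff(px, axis=1, prepend=0.0) + np.diff(py, axis=0, prepend=0.0))

def tv_descent(img, n_steps=20, alpha=0.2, eps=1e-8):
    """Normalized steepest descent on TV (ASD-POCS would adapt alpha; here it is fixed)."""
    x = img.copy()
    for _ in range(n_steps):
        g = tv_gradient(x, eps)
        norm = np.linalg.norm(g)
        if norm == 0:
            break
        x -= alpha * g / norm
    return x
```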

Each algorithm takes:

  • measured_projections: a list/sequence of per-view projections
  • forward_project(volume, projection_index): user callback for one view
  • back_project(projection, projection_index): user callback for one view
  • ReconstructionParameters: shared reconstruction settings
  • an algorithm-specific regularization dataclass where applicable

This keeps the algorithms reusable across parallel, fan, cone, 2D, and 3D setups as long as the caller provides the appropriate single-view projector wrappers.
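The callback pattern can be illustrated with a plain-NumPy SART loop. The names here (sart_sketch, n_iters, relaxation) are hypothetical stand-ins for run_sart and ReconstructionParameters, whose exact signatures are not listed in this README. The update is the textbook SART step x ← x + λ · Aᵢᵀ[(bᵢ − Aᵢx) / (Aᵢ1)] / (Aᵢᵀ1) with elementwise divisions, applied one view at a time using only the two per-view callbacks.

```python
import numpy as np

def sart_sketch(measured, forward_project, back_project, shape,
                n_iters=10, relaxation=0.5, eps=1e-8):
    """SART using only per-view callbacks: forward_project(x, i) and back_project(p, i)."""
    x = np.zeros(shape)
    ones = np.ones(shape)
    # Per-view normalizers: ray sums A_i 1 and pixel sums A_i^T 1
    ray_norm = [forward_project(ones, i) for i in range(len(measured))]
    pix_norm = [back_project(np.ones_like(b), i) for i, b in enumerate(measured)]
    for _ in range(n_iters):
        for i, b in enumerate(measured):
            residual = (b - forward_project(x, i)) / (ray_norm[i] + eps)
            x = x + relaxation * back_project(residual, i) / (pix_norm[i] + eps)
    return x

# Toy usage: each "view" is a small dense matrix standing in for a single-view projector
rng = np.random.default_rng(0)
views = [rng.random((5, 16)) for _ in range(6)]
x_true = rng.random(16)
measured = [V @ x_true for V in views]
x_rec = sart_sketch(measured,
                    lambda v, i: views[i] @ v,
                    lambda p, i: views[i].T @ p,
                    shape=(16,), n_iters=300)
```

With real projectors the two lambdas would wrap single-view calls to the forward and backward functions for whichever geometry is in use.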

Trajectory Generators

Function                                                Geometry
circular_trajectory_2d_parallel(n_views, ...)           2D parallel
sinusoidal_trajectory_2d_parallel(n_views, ...)         2D parallel
custom_trajectory_2d_parallel(n_views, ...)             2D parallel
circular_trajectory_2d_fan(n_views, sid, sdd, ...)      2D fan
sinusoidal_trajectory_2d_fan(n_views, sid, sdd, ...)    2D fan
custom_trajectory_2d_fan(n_views, sid, sdd, ...)        2D fan
circular_trajectory_3d(n_views, sid, sdd, ...)          3D cone
spiral_trajectory_3d(n_views, sid, sdd, ...)            3D cone
sinusoidal_trajectory_3d(n_views, sid, sdd, ...)        3D cone
saddle_trajectory_3d(n_views, sid, sdd, ...)            3D cone
random_trajectory_3d(n_views, sid_mean, sdd_mean, ...)  3D cone
custom_trajectory_3d(n_views, sid, sdd, ...)            3D cone
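As a rough picture of what the 2D parallel-beam generators return, here is a NumPy sketch that builds per-view ray_dir, det_origin and det_u_vec arrays for a circular scan. The conventions are assumptions, not taken from the library: angles cover [0, π) (sufficient for parallel beam), the detector axis is the ray direction rotated by 90°, and det_origin sits at the rotation center. The optional angle_offset_fn hook is a hypothetical stand-in for the custom generators and perturbs the angles to mimic a wobble.

```python
import numpy as np

def parallel_geometry_2d(n_views, angle_offset_fn=None):
    """Hypothetical per-view geometry for a 2D parallel-beam scan."""
    theta = np.linspace(0.0, np.pi, n_views, endpoint=False)
    if angle_offset_fn is not None:
        theta = theta + angle_offset_fn(theta)                     # e.g. a wobble perturbation
    ray_dir = np.stack([np.cos(theta), np.sin(theta)], axis=1)      # unit ray direction
    det_u_vec = np.stack([-np.sin(theta), np.cos(theta)], axis=1)   # detector axis, orthogonal to rays
    det_origin = np.zeros((n_views, 2))                             # detector centered on the rotation axis
    return ray_dir, det_origin, det_u_vec

ray_dir, det_origin, det_u_vec = parallel_geometry_2d(90)
wobble = parallel_geometry_2d(90, angle_offset_fn=lambda t: 0.05 * np.sin(4 * t))
```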

Examples

Ready-to-run scripts are provided in the examples/ directory:

Circular Trajectory (Analytical Reconstruction)

Script Description
examples/circular_trajectory/fbp_parallel.py FBP with ramp filter — 2D parallel beam
examples/circular_trajectory/fbp_fan.py FBP with cosine weighting + ramp filter — 2D fan beam
examples/circular_trajectory/fdk_cone.py FDK with distance weighting + ramp filter — 3D cone beam
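The FBP and FDK scripts weight each projection (cosine weighting for fan beam, distance weighting for cone beam) and ramp-filter it along the detector axis before backprojection. A minimal NumPy sketch of the fan-beam ingredients, assuming a flat detector centered on the central ray and a plain |f| ramp sampled on a zero-padded FFT grid; the scripts may use a windowed or spatially band-limited Ram-Lak kernel instead, and the function names are hypothetical.

```python
import numpy as np

def ramp_filter(sino, detector_spacing=1.0):
    """Filter each projection row (last axis) with |f| in the Fourier domain."""
    n = sino.shape[-1]
    n_pad = 1 << int(np.ceil(np.log2(2 * n)))          # pad to >= 2n to avoid circular wrap-around
    freqs = np.fft.fftfreq(n_pad, d=detector_spacing)  # cycles per unit length
    spectrum = np.fft.fft(sino, n=n_pad, axis=-1) * np.abs(freqs)
    return np.real(np.fft.ifft(spectrum, axis=-1))[..., :n]

def fan_cosine_weights(n_det, detector_spacing, sdd):
    """Flat-detector fan-beam pre-weights sdd / sqrt(sdd^2 + u^2) for a centered detector."""
    u = (np.arange(n_det) - (n_det - 1) / 2) * detector_spacing
    return sdd / np.sqrt(sdd**2 + u**2)

# Fan-beam pre-processing: weight every detector row, then ramp-filter it
sino = np.random.default_rng(0).random((90, 33))
filtered = ramp_filter(sino * fan_cosine_weights(33, 1.0, 1000.0))
```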

Non-Circular Trajectory (Iterative Reconstruction)

Script Description
examples/non_circular_trajectory/iterative_reco_parallel.py Gradient-based iterative reco — sinusoidal & custom wobble trajectories
examples/non_circular_trajectory/iterative_reco_fan.py Gradient-based iterative reco — sinusoidal & custom elliptical trajectories
examples/non_circular_trajectory/iterative_reco_cone.py Gradient-based iterative reco — spiral, sinusoidal, saddle & figure-8 trajectories
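The non-circular examples reconstruct by minimizing a data-fidelity loss using gradients from the differentiable projectors. A framework-free sketch of that idea, assuming plain gradient descent on ½‖Ax − b‖² with a 1/L step (L estimated by power iteration) and a random matrix standing in for the projector; the scripts themselves may use a different optimizer, loss, or step rule.

```python
import numpy as np

def power_iteration(A, n_iter=50, seed=0):
    """Estimate the largest eigenvalue of A^T A (the Lipschitz constant of the gradient)."""
    v = np.random.default_rng(seed).random(A.shape[1])
    for _ in range(n_iter):
        v = A.T @ (A @ v)
        v /= np.linalg.norm(v)
    return float(v @ (A.T @ (A @ v)))   # Rayleigh quotient

def gradient_descent_reco(A, b, n_iter=500):
    """Gradient descent on 0.5 * ||A x - b||^2, recording the loss before each step."""
    L = power_iteration(A)
    x = np.zeros(A.shape[1])
    history = []
    for _ in range(n_iter):
        r = A @ x - b
        history.append(0.5 * float(r @ r))
        x -= (A.T @ r) / L   # gradient is A^T r; any step below 2/lambda_max decreases this quadratic
    return x, history

rng = np.random.default_rng(0)
A = rng.random((40, 16))
x_true = rng.random(16)
x_rec, history = gradient_descent_reco(A, A @ x_true)
```

In diffct_mlx the product Aᵀr would come from the backprojector (or mx.grad) rather than an explicit matrix.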

Run any example with:

conda activate diffct-mlx
python examples/circular_trajectory/fbp_parallel.py

Package Structure

diffct_mlx/
├── __init__.py          # Public API exports
├── constants.py         # MLX-specific constants and dtypes
├── utils.py             # Grid computation utilities
├── geometry.py          # Trajectory generation functions
├── projectors.py        # Differentiable projector functions with VJPs
└── kernels/
    ├── __init__.py
    ├── parallel_beam.py # Metal kernels for 2D parallel beam
    ├── fan_beam.py      # Metal kernels for 2D fan beam
    └── cone_beam.py     # Metal kernels for 3D cone beam

Citation

If you use this library in your research, please cite:

@software{diffct2026,
  author       = {Yipeng Sun and Linda-Sophie Schneider},
  title        = {diffct_mlx: Differentiable CT for Apple Silicon},
  year         = 2026,
}

License

This project is licensed under the Apache 2.0 License — see the LICENSE file for details.

Acknowledgements

This project was highly inspired by:

Issues and contributions are welcome!



Download files


Source Distribution

diffct_mlx-1.0.1.tar.gz (52.6 kB)

Uploaded Source

Built Distribution


diffct_mlx-1.0.1-py3-none-any.whl (65.9 kB)

Uploaded Python 3

File details

Details for the file diffct_mlx-1.0.1.tar.gz.

File metadata

  • Download URL: diffct_mlx-1.0.1.tar.gz
  • Upload date:
  • Size: 52.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for diffct_mlx-1.0.1.tar.gz
Algorithm Hash digest
SHA256 231bb147e09a759d9c372d9f1e46179c427b4359f553b502468b319f4d1cffc9
MD5 0c68bcd72b7462411099258d4a450cea
BLAKE2b-256 502a39a782ad5322b4dd9a6be9a75f859b1234ec56c820c5999f134cb076eff3
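To check a downloaded archive against the SHA256 digest above, a short standard-library sketch (the file name in the example assumes the sdist was saved to the current directory). pip can also enforce digests through --hash entries in a requirements file installed with --require-hashes.

```python
import hashlib

EXPECTED_SHA256 = "231bb147e09a759d9c372d9f1e46179c427b4359f553b502468b319f4d1cffc9"

def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_expected(path, expected=EXPECTED_SHA256):
    """True if the file's SHA-256 equals the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()

# Example: matches_expected("diffct_mlx-1.0.1.tar.gz")
```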


Provenance

The following attestation bundles were made for diffct_mlx-1.0.1.tar.gz:

Publisher: publish.yml on Linda-SophieSchneider/DiffCT-MLX

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file diffct_mlx-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: diffct_mlx-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 65.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for diffct_mlx-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 0f21a698b1f0f2ec1e7baa0947390fdb8e16539c2c61bc9b8f56dd0393d11c6b
MD5 2687d4fbd9988d1190deadceb27f1650
BLAKE2b-256 a9af3a569a41b27fe94ab4a86449a1560eb1cf2b9690539bbb6b2b7810aa3359


Provenance

The following attestation bundles were made for diffct_mlx-1.0.1-py3-none-any.whl:

Publisher: publish.yml on Linda-SophieSchneider/DiffCT-MLX

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
