Differentiable CT reconstruction for Apple Silicon using MLX and custom Metal kernels
diffct_mlx: Differentiable CT for Apple Silicon
A high-performance, differentiable computed tomography (CT) reconstruction library built with MLX and custom Metal kernels, optimized for Apple Silicon (M-series) chips.
This is the Apple Silicon port of diffct, replacing CUDA/PyTorch with MLX/Metal for native M-series GPU acceleration.
Features
- Apple Silicon Native: Custom Metal kernels via `mx.fast.metal_kernel`, no CUDA required
- Differentiable: End-to-end gradient propagation using `mx.custom_function` with custom VJPs
- Siddon Ray-Tracing: Bilinear (2D) and trilinear (3D) interpolation for accurate projection
- Atomic Backprojection: Thread-safe gradient accumulation using Metal atomic operations
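To make the interpolation idea concrete, here is a minimal pure-Python sketch of a ray sum with bilinear sampling. It assumes uniform steps along the ray, unit pixel spacing, and zero values outside the image. The function names `bilinear_sample` and `ray_sum` are hypothetical and are not part of diffct_mlx; the actual Metal kernels may traverse the grid differently.

```python
import math

def bilinear_sample(image, x, y):
    """Bilinearly interpolate a 2D image (list of rows) at (x, y); zero outside."""
    h, w = len(image), len(image[0])
    if x < 0 or y < 0 or x > w - 1 or y > h - 1:
        return 0.0
    x0, y0 = int(math.floor(x)), int(math.floor(y))
    x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
    fx, fy = x - x0, y - y0
    top = (1 - fx) * image[y0][x0] + fx * image[y0][x1]
    bottom = (1 - fx) * image[y1][x0] + fx * image[y1][x1]
    return (1 - fy) * top + fy * bottom

def ray_sum(image, start, direction, n_steps, step=0.5):
    """Approximate a line integral by summing equally spaced samples (assumed scheme)."""
    norm = math.hypot(direction[0], direction[1])
    dx, dy = direction[0] / norm, direction[1] / norm
    total = 0.0
    for k in range(n_steps):
        t = k * step
        total += bilinear_sample(image, start[0] + t * dx, start[1] + t * dy) * step
    return total
```

Because every sample is a weighted sum of four pixels, the ray sum is linear in the image, which is what makes a matching backprojection usable as its gradient.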
Supported Geometries
| Geometry | Forward | Backward | Differentiable |
|---|---|---|---|
| 2D Parallel Beam | ✅ | ✅ | ✅ |
| 2D Fan Beam | ✅ | ✅ | ✅ |
| 3D Cone Beam | ✅ | ✅ | ✅ |
Trajectory Generators
- Circular — standard single-rotation scan
- Spiral / Helical — helical CT with z-axis translation (3D)
- Sinusoidal — variable source-to-isocenter distance
- Saddle — combined z-oscillation and radial variation (3D)
- Random — perturbed circular with configurable noise (3D)
- Custom — user-defined source path functions
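As an illustration of what a trajectory generator produces, the sketch below builds per-view ray directions and detector axes for a circular parallel-beam scan in plain Python. The angular range, sign conventions, and the detector origin at the isocenter are assumptions made for this example; `circular_parallel_geometry` is a hypothetical name and the library's own generators may use different conventions and return MLX arrays.

```python
import math

def circular_parallel_geometry(n_views, total_angle=math.pi):
    """Return per-view (ray_dir, det_origin, det_u_vec) tuples for a 2D parallel-beam scan."""
    views = []
    for i in range(n_views):
        theta = total_angle * i / n_views
        ray_dir = (math.cos(theta), math.sin(theta))
        # Detector axis is perpendicular to the ray direction.
        det_u_vec = (-math.sin(theta), math.cos(theta))
        # Assumption: the detector is centered on the isocenter for every view.
        det_origin = (0.0, 0.0)
        views.append((ray_dir, det_origin, det_u_vec))
    return views
```

A custom trajectory would follow the same pattern, with `theta` or `det_origin` replaced by a user-defined function of the view index.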
Quick Start
Prerequisites
- Apple Silicon Mac (M1/M2/M3/M4 series)
- Python 3.10+
- macOS 13.5+
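If you want to check these prerequisites from Python before installing, a small standard-library sketch could look like the following. The helper `meets_prerequisites` is hypothetical and not shipped with diffct_mlx; note that a Python interpreter running under Rosetta reports `x86_64` and would fail this check.

```python
import platform
import sys

def meets_prerequisites(system, machine, python_version, macos_version):
    """Check the listed prerequisites against explicit values."""
    def as_tuple(version):
        # "13.5.1" -> (13, 5, 1); an empty string (non-macOS) -> ()
        return tuple(int(part) for part in version.split(".") if part.isdigit())
    return (
        system == "Darwin"
        and machine == "arm64"
        and tuple(python_version[:2]) >= (3, 10)
        and as_tuple(macos_version) >= (13, 5)
    )

if __name__ == "__main__":
    ok = meets_prerequisites(
        platform.system(), platform.machine(),
        sys.version_info, platform.mac_ver()[0],
    )
    print("Prerequisites met" if ok else "Prerequisites not met")
```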
Installation from PyPI
pip install diffct_mlx
Installation from Source
# Clone the repository
git clone https://github.com/Linda-SophieSchneider/diffct_arbit.git
cd diffct_arbit
# Create and activate conda environment
conda create -n diffct-mlx python=3.11
conda activate diffct-mlx
# Install MLX and dependencies
pip install mlx numpy matplotlib
# Install diffct_mlx
pip install -e .
Build a Release Distribution
python -m pip install --upgrade build twine
python -m build
python -m twine check dist/*
This creates both a source distribution and a wheel in dist/.
Publish to TestPyPI or PyPI
# TestPyPI
python -m twine upload --repository testpypi dist/*
# PyPI
python -m twine upload dist/*
For publishing from your machine, create an API token on PyPI and use it as the
password with the username __token__.
If you publish through GitHub Actions, prefer PyPI Trusted Publishing instead of storing a long-lived API token.
The repository release workflow is documented in RELEASING.md.
Basic Usage
import mlx.core as mx
import diffct_mlx
# Create a 64x64 test image
image = mx.ones((64, 64), dtype=mx.float32)
# Generate parallel beam geometry (90 views)
ray_dir, det_origin, det_u_vec = diffct_mlx.circular_trajectory_2d_parallel(90)
# Forward projection → sinogram
sino = diffct_mlx.parallel_forward(
    image, ray_dir, det_origin, det_u_vec,
    num_detectors=92, detector_spacing=1.0, voxel_spacing=1.0
)
# Backprojection → reconstruction
reco = diffct_mlx.parallel_backward(
    sino, ray_dir, det_origin, det_u_vec,
    H=64, W=64, detector_spacing=1.0, voxel_spacing=1.0
)
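The value `num_detectors=92` for a 64x64 image is presumably chosen so that the detector spans the image diagonal at every view angle. The README does not state this, so treat it as an inference. The arithmetic can be sketched with a hypothetical helper:

```python
import math

def min_detectors(height, width, voxel_spacing=1.0, detector_spacing=1.0):
    """Smallest detector count whose span covers the image diagonal."""
    diagonal = math.hypot(height, width) * voxel_spacing
    return math.ceil(diagonal / detector_spacing)

print(min_detectors(64, 64))  # 91, so 92 detectors leave a small margin
```

For other image sizes or spacings, the same calculation gives a lower bound on the detector count needed to avoid truncated projections.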
Gradient Computation
Since all projectors are differentiable, you can compute gradients directly:
import mlx.core as mx
import diffct_mlx
def loss_fn(image):
    ray_dir, det_origin, det_u_vec = diffct_mlx.circular_trajectory_2d_parallel(90)
    sino = diffct_mlx.parallel_forward(image, ray_dir, det_origin, det_u_vec, 92)
    return mx.sum(sino ** 2)

image = mx.ones((64, 64), dtype=mx.float32)
grad_fn = mx.grad(loss_fn)
gradient = grad_fn(image)
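To see what this gradient looks like mathematically: for a linear projector A, the loss sum((A x)^2) has gradient 2 Aᵀ A x, that is, twice the backprojection of the sinogram. The pure-Python sketch below checks this on a tiny matrix that stands in for the projector. The matrix and all function names are illustrative assumptions and do not use diffct_mlx.

```python
def forward(A, x):
    """Toy forward projection: matrix-vector product A x."""
    return [sum(a * v for a, v in zip(row, x)) for row in A]

def backward(A, y):
    """Toy backprojection: transpose product A^T y."""
    return [sum(A[i][j] * y[i] for i in range(len(A))) for j in range(len(A[0]))]

def loss(A, x):
    return sum(s * s for s in forward(A, x))

def analytic_grad(A, x):
    """Gradient of sum((A x)^2), computed as 2 * A^T (A x)."""
    return [2.0 * g for g in backward(A, forward(A, x))]

def numeric_grad(A, x, eps=1e-6):
    """Central finite differences, used only as a cross-check."""
    grads = []
    for j in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[j] += eps
        xm[j] -= eps
        grads.append((loss(A, xp) - loss(A, xm)) / (2 * eps))
    return grads

A = [[1, 0, 2], [0, 1, 1]]
x = [1.0, 1.0, 1.0]
print(analytic_grad(A, x))  # [6.0, 4.0, 16.0]
```

The same adjoint relationship is what allows a backprojector to serve as the vector-Jacobian product of a forward projector.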
3D Cone Beam Example
import mlx.core as mx
import diffct_mlx
# Create a 32x64x64 volume (D, H, W)
volume = mx.ones((32, 64, 64), dtype=mx.float32)
# Generate cone beam geometry
src, det_c, det_u, det_v = diffct_mlx.circular_trajectory_3d(
    n_views=60, sid=500.0, sdd=1000.0
)
# Forward projection
sino = diffct_mlx.cone_forward(
    volume, src, det_c, det_u, det_v,
    det_u=64, det_v=32, du=1.0, dv=1.0, voxel_spacing=1.0
)
# Backprojection
reco = diffct_mlx.cone_backward(
    sino, src, det_c, det_u, det_v,
    D=32, H=64, W=64, du=1.0, dv=1.0, voxel_spacing=1.0
)
API Reference
Projectors
| Function | Description |
|---|---|
| `parallel_forward(image, ray_dir, det_origin, det_u_vec, ...)` | 2D parallel beam forward projection |
| `parallel_backward(sinogram, ray_dir, det_origin, det_u_vec, ...)` | 2D parallel beam backprojection |
| `fan_forward(image, src_pos, det_center, det_u_vec, ...)` | 2D fan beam forward projection |
| `fan_backward(sinogram, src_pos, det_center, det_u_vec, ...)` | 2D fan beam backprojection |
| `cone_forward(volume, src_pos, det_center, det_u_vec, det_v_vec, ...)` | 3D cone beam forward projection |
| `cone_backward(sinogram, src_pos, det_center, det_u_vec, det_v_vec, ...)` | 3D cone beam backprojection |
Iterative Reconstruction Algorithms
The package also exposes callback-based iterative reconstruction algorithms that are independent of geometry and dimensionality:
| Function | Description |
|---|---|
| `run_sart(...)` | SART with user-provided single-view forward/backprojectors |
| `run_tv_pocs(...)` | TV-POCS using the same projector callback pattern |
| `run_asd_pocs(...)` | ASD-POCS with adaptive TV step-size damping |
| `run_awtv_pocs(...)` | AwTV-POCS with edge-adaptive weighted TV regularization |
Each algorithm takes:
- `measured_projections`: a list/sequence of per-view projections
- `forward_project(volume, projection_index)`: user callback for one view
- `back_project(projection, projection_index)`: user callback for one view
- `ReconstructionParameters`: shared reconstruction settings
- an algorithm-specific regularization dataclass where applicable
This keeps the algorithms reusable across parallel, fan, cone, 2D, and 3D setups as long as the caller provides the appropriate single-view projector wrappers.
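The callback pattern can be illustrated with a small SART-style loop in plain Python. This is a sketch under stated assumptions, not the implementation behind `run_sart`: the function `sart_sketch`, its signature, and the toy 2x2 operators are hypothetical. The normalization terms are obtained by projecting and backprojecting arrays of ones, so the loop needs nothing beyond the two callbacks.

```python
def sart_sketch(measured, forward_project, back_project, n_voxels, n_iters=50, relax=1.0):
    """SART-style update driven purely by per-view projector callbacks."""
    volume = [0.0] * n_voxels
    ones_volume = [1.0] * n_voxels
    for _ in range(n_iters):
        for i, measured_view in enumerate(measured):
            row_sums = forward_project(ones_volume, i)              # ray lengths
            col_sums = back_project([1.0] * len(measured_view), i)  # voxel weights
            estimate = forward_project(volume, i)
            residual = [
                (m - e) / r if r > 0 else 0.0
                for m, e, r in zip(measured_view, estimate, row_sums)
            ]
            update = back_project(residual, i)
            volume = [
                v + relax * u / c if c > 0 else v
                for v, u, c in zip(volume, update, col_sums)
            ]
    return volume

# Toy operators for a flattened 2x2 image: view 0 sums rows, view 1 sums columns.
VIEWS = [
    [[1, 1, 0, 0], [0, 0, 1, 1]],
    [[1, 0, 1, 0], [0, 1, 0, 1]],
]

def forward_project(volume, projection_index):
    return [sum(a * v for a, v in zip(row, volume)) for row in VIEWS[projection_index]]

def back_project(projection, projection_index):
    rows = VIEWS[projection_index]
    return [sum(rows[r][j] * projection[r] for r in range(len(rows))) for j in range(4)]

measured = [[3.0, 7.0], [4.0, 6.0]]  # projections of the image [1, 2, 3, 4]
reco = sart_sketch(measured, forward_project, back_project, n_voxels=4)
```

Swapping the toy callbacks for wrappers around single-view parallel, fan, or cone projectors would leave the loop itself unchanged, which is the point of the callback design.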
Trajectory Generators
| Function | Geometry |
|---|---|
| `circular_trajectory_2d_parallel(n_views, ...)` | 2D parallel |
| `sinusoidal_trajectory_2d_parallel(n_views, ...)` | 2D parallel |
| `custom_trajectory_2d_parallel(n_views, ...)` | 2D parallel |
| `circular_trajectory_2d_fan(n_views, sid, sdd, ...)` | 2D fan |
| `sinusoidal_trajectory_2d_fan(n_views, sid, sdd, ...)` | 2D fan |
| `custom_trajectory_2d_fan(n_views, sid, sdd, ...)` | 2D fan |
| `circular_trajectory_3d(n_views, sid, sdd, ...)` | 3D cone |
| `spiral_trajectory_3d(n_views, sid, sdd, ...)` | 3D cone |
| `sinusoidal_trajectory_3d(n_views, sid, sdd, ...)` | 3D cone |
| `saddle_trajectory_3d(n_views, sid, sdd, ...)` | 3D cone |
| `random_trajectory_3d(n_views, sid_mean, sdd_mean, ...)` | 3D cone |
| `custom_trajectory_3d(n_views, sid, sdd, ...)` | 3D cone |
Examples
Ready-to-run scripts are provided in the examples/ directory:
Circular Trajectory (Analytical Reconstruction)
| Script | Description |
|---|---|
| `examples/circular_trajectory/fbp_parallel.py` | FBP with ramp filter, 2D parallel beam |
| `examples/circular_trajectory/fbp_fan.py` | FBP with cosine weighting + ramp filter, 2D fan beam |
| `examples/circular_trajectory/fdk_cone.py` | FDK with distance weighting + ramp filter, 3D cone beam |
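For readers unfamiliar with the ramp filter mentioned in these scripts, the sketch below builds the textbook spatial-domain Ram-Lak kernel and applies it to one projection row by direct convolution. It is a plain-Python illustration under that assumption; the example scripts may implement the filter differently (for instance in the frequency domain), and the names `ram_lak_kernel` and `ramp_filter` are hypothetical.

```python
import math

def ram_lak_kernel(half_width, spacing=1.0):
    """Discrete Ram-Lak kernel h[n] for n = -half_width .. half_width."""
    kernel = []
    for n in range(-half_width, half_width + 1):
        if n == 0:
            kernel.append(1.0 / (4.0 * spacing ** 2))
        elif n % 2 == 0:
            kernel.append(0.0)  # even taps vanish
        else:
            kernel.append(-1.0 / (math.pi ** 2 * n ** 2 * spacing ** 2))
    return kernel

def ramp_filter(projection, spacing=1.0):
    """Filter one projection row by direct convolution with the Ram-Lak kernel."""
    n = len(projection)
    kernel = ram_lak_kernel(n - 1, spacing)
    filtered = []
    for i in range(n):
        acc = 0.0
        for j in range(n):
            # kernel index (n - 1) corresponds to offset zero
            acc += projection[j] * kernel[(i - j) + (n - 1)]
        filtered.append(acc * spacing)
    return filtered
```

Filtering a unit impulse returns the kernel itself, which is a quick way to sanity-check an implementation before using it in a backprojection pipeline.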
Non-Circular Trajectory (Iterative Reconstruction)
| Script | Description |
|---|---|
| `examples/non_circular_trajectory/iterative_reco_parallel.py` | Gradient-based iterative reconstruction: sinusoidal & custom wobble trajectories |
| `examples/non_circular_trajectory/iterative_reco_fan.py` | Gradient-based iterative reconstruction: sinusoidal & custom elliptical trajectories |
| `examples/non_circular_trajectory/iterative_reco_cone.py` | Gradient-based iterative reconstruction: spiral, sinusoidal, saddle & figure-8 trajectories |
Run any example with:
conda activate diffct-mlx
python examples/circular_trajectory/fbp_parallel.py
Package Structure
diffct_mlx/
├── __init__.py # Public API exports
├── constants.py # MLX-specific constants and dtypes
├── utils.py # Grid computation utilities
├── geometry.py # Trajectory generation functions
├── projectors.py # Differentiable projector functions with VJPs
└── kernels/
    ├── __init__.py
    ├── parallel_beam.py # Metal kernels for 2D parallel beam
    ├── fan_beam.py # Metal kernels for 2D fan beam
    └── cone_beam.py # Metal kernels for 3D cone beam
Citation
If you use this library in your research, please cite:
@software{diffct2026,
  author = {Yipeng Sun and Linda-Sophie Schneider},
  title = {diffct_mlx: Differentiable CT for Apple Silicon},
  year = 2026,
}
License
This project is licensed under the Apache 2.0 License — see the LICENSE file for details.
Acknowledgements
This project was heavily inspired by:
- PYRO-NN
- geometry_gradients_CT
- MLX by Apple
Issues and contributions are welcome!