Skip to main content

Metal-accelerated 3D Gaussian Splatting for Apple Silicon

Project description

msplat

A 3D Gaussian Splatting training engine for Apple Silicon, built entirely on Metal. No external dependencies beyond system frameworks.

The entire training pipeline: projection, sorting, rasterization, SSIM loss, backward pass, Adam optimizer, and densification runs as fused Metal compute shaders.

The result is a self-contained engine that trains a full-resolution Mip-NeRF 360 scene in ~80 seconds and renders it at ~350 FPS on an M4 Max.

Python and Swift bindings are provided, as well as a standalone C++ CLI.

Why this exists

The original 3D Gaussian Splatting implementation is CUDA-only. Ports to other frameworks (gsplat, taichi-3dgs, etc.) still depend on PyTorch for autograd, optimizer state, and tensor management. This means ~2GB of framework overhead, Python GIL contention, and no straightforward path to native macOS/iOS integration.

Architecture

core/metal/msplat_metal.metal    ← Compute kernels
core/src/                        ← C++ training loop, dataset loaders, SSIM eval
core/include/                    ← MTensor (lightweight GPU tensor), Model, API headers
python/bindings.cpp              ← nanobind Python module
swift/Sources/Msplat/            ← Swift package (via C API bridge)
cli/msplat.cpp                   ← C++ CLI

Training pipeline (single iteration)

Each training step dispatches all work into one Metal command encoder:

Forward:
  project_and_sh_forward     ← fused 3D→2D projection + spherical harmonics
  prefix_sum + scatter       ← gaussian→tile intersection mapping
  bitonic_sort_per_tile      ← tile-local depth sort + inline data packing
  nd_rasterize_forward       ← per-pixel alpha compositing (16x16 tiles)
  ssim_h_fwd + ssim_v_fwd   ← separable 11-tap SSIM + L1 loss

Backward:
  ssim_h_bwd + ssim_v_bwd   ← separable SSIM gradient
  rasterize_backward         ← per-pixel backward compositing
  project_and_sh_backward    ← fused projection + SH VJP + SH Adam update
  fused_adam (×4 groups)     ← optimizer step (means, scales, quats, opacity)
  accumulate_grad_stats      ← gradient norms for densification

Key design decisions

Tile-local bitonic sort instead of global radix sort. Each 16x16 tile independently sorts its gaussians (up to 2048) in threadgroup shared memory. The sort kernel also packs per-gaussian data (xy, opacity, conic, color) inline, eliminating a separate scatter dispatch.

GPU-resident densification. The split/clone/cull cycle never leaves the GPU. Classification, growth, and compaction are all compute kernels operating on device buffers. No CPU readback of gradient statistics or gaussian counts.

Fused kernels. Projection and spherical harmonic evaluation share registers (avoid a device memory round-trip for world-space position). The backward pass recomputes 3D covariance from scales/quaternions on-the-fly rather than storing it. SH backward gradients are computed in registers and fed directly into Adam updates, eliminating a separate gradient buffer write/read cycle. The remaining four parameter groups use fused Adam dispatches.

Separable SSIM. The 11x11 Gaussian-weighted SSIM window decomposes into two 1D passes (horizontal then vertical), reducing per-pixel work from 121 to 22 multiply-adds. Forward and backward each take two kernels, using threadgroup shared memory for the intermediate statistics.

Depth-chunked rasterization. For tiles with extreme gaussian counts, the forward pass splits into 512-gaussian chunks with a merge kernel that reconstructs absolute transmittance. The backward pass uses precomputed prefix/suffix transmittance to avoid re-traversal.

Installation & Usage

Python

pip install msplat
import msplat

dataset = msplat.load_dataset("path/to/colmap/", eval_mode=True)
config = msplat.TrainingConfig(iterations=7000, num_downscales=0)
trainer = msplat.GaussianTrainer(dataset, config)

trainer.train(lambda s: print(f"step={s.iteration} splats={s.splat_count:,}"),
              callback_every=100)

trainer.export_ply("output.ply")
trainer.save_checkpoint("checkpoint.msplat")  # save/resume training
metrics = trainer.evaluate()
print(f"PSNR: {metrics['psnr']:.2f}  SSIM: {metrics['ssim']:.3f}")

# Render from arbitrary viewpoints
pose = dataset.camera_pose(0)   # (4, 4) cam-to-world matrix
img = trainer.render_from_pose(pose)  # numpy (H, W, 3) float32

Supported dataset formats: COLMAP, Nerfstudio, Polycam.

Type stubs (_core.pyi) are included for IDE autocompletion.

CLI

pip install msplat[cli]
msplat-train path/to/dataset -n 7000 --eval

Swift

Requires Xcode and CMake (brew install cmake).

// Package.swift
dependencies: [
    .package(url: "https://github.com/rayanht/msplat.git", from: "1.1.0")
]

Build the XCFramework (one-time, from repo root):

./scripts/build-xcframework.sh
import Msplat

let dataset = GaussianDataset(path: "path/to/colmap/", downscaleFactor: 4.0)
let trainer = GaussianTrainer(dataset: dataset)

for _ in 0..<1000 {
    let stats = trainer.step()
    print("step=\(stats.iteration) splats=\(stats.splatCount)")
}

trainer.exportPly(to: "output.ply")

// Render from arbitrary viewpoints
let pose = dataset.cameraPose(at: 0)  // [Float] cam-to-world matrix
let img = trainer.renderFromPose(camToWorld: pose)

C++ CLI

cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j
./build/msplat path/to/dataset -n 7000 --eval

Build from source

git clone https://github.com/rayanht/msplat.git && cd msplat

# Python
pip install -e .

# C++ CLI + static lib
cmake -B build -DCMAKE_BUILD_TYPE=Release && cmake --build build -j

# Swift XCFramework
./scripts/build-xcframework.sh
cd swift && swift build

Requires macOS 14+, Apple Silicon. No external dependencies.

Benchmarks

mipnerf360, M4 Max. msplat runs 7K iterations with no downscales:

msplat-train path/to/scene -n 7000 --num-downscales 0 --eval
Scene msplat PSNR msplat SSIM msplat wall time gsplat PSNR gsplat SSIM gsplat wall time
bicycle 23.20 0.604 65s 23.71 0.668 ~335s
counter 27.40 0.880 84s 27.14 0.878 ~335s
garden 25.74 0.785 89s 26.30 0.833 ~335s
room 29.96 0.897 78s 29.21 0.893 ~335s

30K iterations (garden)

msplat-train path/to/garden -n 30000 --num-downscales 0 --eval
msplat gsplat
PSNR 27.12 27.32
SSIM 0.853 0.865
Gaussians 3.49M
Wall time 804s ~2149s

gsplat numbers from docs.gsplat.studio (TITAN RTX). gsplat wall times are the reported average across all mipnerf360 scenes (per-scene times not published).

Performance history (wall time, M4 Max)

Scene v1.0 v1.1.3 Speedup
bicycle 7K 82s 65s 1.26x
counter 7K 91s 84s 1.08x
garden 7K 107s 89s 1.20x
room 7K 85s 78s 1.09x
garden 30K 1039s 804s 1.29x

v1.1.3 fuses SH backward gradients into Adam optimizer updates, eliminating ~600 MB/iter of device memory traffic at 1.5M gaussians. Speedup scales with gaussian count.

License

Apache 2.0

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release.See tutorial on generating distribution archives.

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

msplat-1.1.3-cp313-cp313-macosx_14_0_arm64.whl (654.8 kB view details)

Uploaded CPython 3.13macOS 14.0+ ARM64

msplat-1.1.3-cp312-cp312-macosx_14_0_arm64.whl (654.9 kB view details)

Uploaded CPython 3.12macOS 14.0+ ARM64

File details

Details for the file msplat-1.1.3-cp313-cp313-macosx_14_0_arm64.whl.

File metadata

File hashes

Hashes for msplat-1.1.3-cp313-cp313-macosx_14_0_arm64.whl
Algorithm Hash digest
SHA256 f3b850db044c7e10ee60b526febd9ebe2eb9d9093d609d253452efd8c4dec446
MD5 07b0a892385cb788a27f71b9ad42cc2e
BLAKE2b-256 8795e82e8f4993e06bcceea00e76c78d9028d955a6fc3c129a8f061137973707

See more details on using hashes here.

Provenance

The following attestation bundles were made for msplat-1.1.3-cp313-cp313-macosx_14_0_arm64.whl:

Publisher: release.yml on rayanht/msplat

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file msplat-1.1.3-cp312-cp312-macosx_14_0_arm64.whl.

File metadata

File hashes

Hashes for msplat-1.1.3-cp312-cp312-macosx_14_0_arm64.whl
Algorithm Hash digest
SHA256 86c7bd946ae4f7666a1a5a1c73bcf464c3f6c189275bd0cfc16202199bb2d650
MD5 a49bd8c88653dce956fd6f05fd9b2151
BLAKE2b-256 5af7ca37067c638c6bb4ec4a41a1606b35d61b8300f41d5c211d71428f060be3

See more details on using hashes here.

Provenance

The following attestation bundles were made for msplat-1.1.3-cp312-cp312-macosx_14_0_arm64.whl:

Publisher: release.yml on rayanht/msplat

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page