Differentiable JAX-native magnetic field modeling with Magpylib-compatible APIs
magpylib_jax
Differentiable magnetic field modeling in JAX with Magpylib-compatible APIs, parity gates, and profiling/benchmark CI. magpylib_jax is designed for optimization, inverse design, and simulation pipelines that need Magpylib-style ergonomics together with jax.jit, jax.grad, jax.jacrev, and XLA compilation.
Why this project exists
magpylib_jax targets the gap between two requirements that are usually in tension:
- you want the closed-form, geometry-specific magnetic field models that make Magpylib useful,
- you also want a field pipeline that can be differentiated, compiled, and embedded in outer optimization loops.
This repository keeps the high-level user model close to upstream Magpylib while replacing the numerical core with JAX-first implementations and a JIT-safe field path.
What you get
- End-to-end differentiable `getB`/`getH`/`getJ`/`getM`
- JIT-safe high-level field evaluation by default
- Magpylib-style object API: `Collection`, `Sensor`, path/orientation semantics, squeeze behavior
- Analytical kernels for dipoles, loops, line currents, polygonal current sheets, and permanent magnets
- Parity gates against upstream Magpylib, including mirrored upstream test categories
- CI/CD coverage for lint, typing, docs, parity, benchmarks, profiling, and PyPI release builds
- Python support from 3.10 onward
- Unpinned core package dependencies in `pyproject.toml`
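To give a rough sense of what an analytical, differentiable kernel looks like, the sketch below writes the point-dipole field B = mu0/(4*pi) * (3*r_hat*(m . r_hat) - m) / |r|^3 directly in JAX and composes it with `jax.jit` and `jax.jacrev`. It is an illustration only, not magpylib_jax's internal dipole kernel; the function name `dipole_B` and the classical value mu0 = 4*pi*1e-7 are assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp

MU0 = 4e-7 * jnp.pi  # vacuum permeability in T*m/A (classical value, assumed for this sketch)

def dipole_B(moment, r):
    """Field (T) at position r (m) of a point dipole `moment` (A*m^2) at the origin."""
    dist = jnp.linalg.norm(r)
    r_hat = r / dist
    return MU0 / (4 * jnp.pi) * (3 * r_hat * jnp.dot(moment, r_hat) - moment) / dist**3

# Compile the field evaluation and its spatial Jacobian G[i, j] = dB_i / dr_j.
field = jax.jit(dipole_B)
grad_tensor = jax.jit(jax.jacrev(dipole_B, argnums=1))

m = jnp.array([0.0, 0.0, 1.0])
r = jnp.array([0.3, -0.2, 0.5])
B = field(m, r)
G = grad_tensor(m, r)

# Away from the source, B is divergence-free (trace ~ 0) and curl-free (G symmetric).
print(B, jnp.trace(G))
```

The divergence and symmetry checks are cheap physics sanity tests that any differentiable field model should pass away from its sources.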
Implemented source families
- `misc.Dipole`
- `current.Circle`
- `current.Polyline`
- `current.TriangleSheet`
- `current.TriangleStrip`
- `misc.Triangle`
- `magnet.Cuboid`
- `magnet.Cylinder`
- `magnet.CylinderSegment`
- `magnet.Sphere`
- `magnet.Tetrahedron`
- `magnet.TriangularMesh`
Installation
```shell
pip install magpylib-jax
```
For local development, tests, and docs:
```shell
python -m venv .venv
source .venv/bin/activate
pip install -e '.[test,docs]'
pytest
```
For GPU-backed environments, install the appropriate jax/jaxlib build for your platform first, then install magpylib-jax.
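After installing a GPU-enabled jaxlib, a quick way to confirm that JAX (and therefore magpylib_jax) will actually run on the accelerator is to ask JAX which backend and devices it sees. This uses only core JAX calls; the backend name you get depends on which jaxlib build was installed.

```python
import jax

# Typically "cpu", "gpu", or "tpu", depending on the jaxlib build that was picked up.
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)
```

If this reports `cpu` on a GPU machine, the accelerator build of jaxlib was most likely not installed or not found.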
Quick example
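```python
import jax
import jax.numpy as jnp
import magpylib_jax as mpj

# Enable float64 before any arrays are created.
jax.config.update("jax_enable_x64", True)

# dimension = (inner radius, outer radius, height, start angle, end angle), angles in degrees
src = mpj.magnet.CylinderSegment(
    polarization=(0.1, -0.2, 0.3),
    dimension=(0.4, 1.2, 1.1, -30.0, 110.0),
)
obs = jnp.array([1.2, 0.2, 0.4])
B = src.getB(obs)

def bz(r2):
    # Rebuild the magnet with outer radius r2 and return B_z at the observer.
    trial = mpj.magnet.CylinderSegment(
        polarization=(0.1, -0.2, 0.3),
        dimension=(0.4, r2, 1.1, -30.0, 110.0),
    )
    return trial.getB(obs)[2]

print(B)
print(jax.grad(bz)(1.2))  # dB_z / d(outer radius)
```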
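Before trusting a derivative such as `jax.grad(bz)(1.2)` inside an optimizer, it is worth comparing it with a central finite difference. The helper below is a generic sketch; `check_grad` is a hypothetical name, not part of magpylib_jax. It is demonstrated on a closed-form function so it runs on its own, and it should also accept a scalar function like `bz` unchanged.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the finite difference accurate

def check_grad(f, x, eps=1e-6):
    """Return (autodiff, central finite-difference) derivatives of scalar f at x."""
    x = jnp.asarray(x, dtype=jnp.float64)
    auto = jax.grad(f)(x)
    fd = (f(x + eps) - f(x - eps)) / (2 * eps)
    return auto, fd

# Stand-in: on-axis B_z of a point dipole, 2*m/r^3 in units where mu0*m/(4*pi) = 1.
def bz_dipole(r):
    return 2.0 / r**3

auto, fd = check_grad(bz_dipole, 1.2)
print(auto, fd, abs(auto - fd) / abs(auto))
```

A relative mismatch far above roughly `eps**2` usually points at a masking branch, a non-differentiable edge case, or a precision problem rather than at the finite difference itself.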
Documentation map
- Overview: scope, supported objects, validation strategy, architectural intent
- Quickstart: install, first field computation, first gradient, troubleshooting
- Equation Models: field conventions, model equations, geometric reductions, derivation notes
- Numerics: stability, masking, singular behavior, precision, differentiation notes
- Examples: object API, functional API, optimization loops, performance workflows
- Architecture and Source Map: clickable source-code guide to the repository internals
- Testing and Validation: CI/CD gates, parity strategy, coverage, compatibility matrix
- Performance: profiling workflow, hotspot kernels, JIT entrypoints, memory behavior
- Parity Strategy
- Parity Checklist
- API Reference
- Changelog
JIT-safe getB
magpylib_jax.getB/getH/getJ/getM runs through a JIT-safe core by default. That path preserves Magpylib-style behavior while making the computational graph usable inside larger JAX programs.
Important notes:
- `output="dataframe"` is supported for compatibility, but it is intentionally evaluated outside JIT.
- `pixel_agg` reducers support `mean`, `sum`, `min`, and `max` on the JIT-safe path.
- Repeated object evaluations reuse preparation caches for sources, sensors, orientation matrices, collection flattening, `TriangularMesh` geometry, and `CylinderSegment` face geometry.
- Circle-heavy workloads use a dedicated fast path to reduce host overhead and memory pressure.
- For benchmark-quality timing, wrap the result in `jax.block_until_ready(...)`.
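Because JAX dispatches work asynchronously, a timer stopped as soon as a call returns can measure only dispatch time. A minimal timing sketch, assuming you want the first call (tracing plus compilation) reported separately from steady-state runtime; `time_jitted` is a hypothetical helper, shown on a plain `jnp` workload so it runs standalone, and the same pattern should apply to a jitted wrapper around `src.getB`.

```python
import time

import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=10):
    """Return (first-call seconds, mean steady-state seconds) for jax.jit(fn)."""
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # first call includes tracing and compilation
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(repeats):
        out = jitted(*args)
    jax.block_until_ready(out)  # wait for queued device work before stopping the clock
    steady = (time.perf_counter() - t0) / repeats
    return first, steady

# Stand-in workload: inverse-cube falloff evaluated at many observer positions.
def inv_cube(obs):
    return 1.0 / jnp.sum(obs**2, axis=-1) ** 1.5

obs = jnp.linspace(0.5, 2.0, 300_000).reshape(-1, 3)
first, steady = time_jitted(inv_cube, obs)
print(f"first call: {first:.4f} s, steady state: {steady * 1e6:.1f} us")
```

Reporting the two numbers separately keeps one-off compilation cost from hiding regressions in the compiled kernel itself.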
Differentiable fitting example
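```python
import jax
import jax.numpy as jnp
import magpylib_jax as mpj

# Observation points (m) and target flux density (T) at each point.
obs = jnp.array([[0.2, 0.1, 0.4], [0.5, 0.0, 0.7]])
target = jnp.array([[2.0e-4, 0.0, 3.0e-4], [1.0e-4, 0.0, 2.0e-4]])

def loss_fn(pol):
    # Rebuild the magnet for the candidate polarization and compare fields.
    src = mpj.magnet.Cuboid(dimension=(1.0, 0.8, 1.2), polarization=pol)
    pred = src.getB(obs)
    return jnp.mean((pred - target) ** 2)

# Build and compile the gradient once instead of re-wrapping it every step.
grad_fn = jax.jit(jax.grad(loss_fn))
pol = jnp.array([0.05, -0.02, 0.08])
for _ in range(50):
    pol = pol - 1e-1 * grad_fn(pol)
```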
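The same fitting pattern can also run entirely on-device by moving the update loop into `jax.lax.fori_loop` and compiling it once. The sketch below is self-contained and rests on stated assumptions: it fits the moment of a point dipole written directly in JAX, in units where mu0/(4*pi) = 1 to keep gradients well scaled, rather than a magpylib_jax source; `dipole_B` and `fit_moment` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def dipole_B(moment, r):
    """Point-dipole field at positions r with shape (N, 3), in units where mu0/(4*pi) = 1."""
    dist = jnp.linalg.norm(r, axis=-1, keepdims=True)
    r_hat = r / dist
    m_dot_r = jnp.sum(r_hat * moment, axis=-1, keepdims=True)
    return (3 * r_hat * m_dot_r - moment) / dist**3

# Synthetic "measurements" generated from a known moment at three observer positions.
obs = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
true_moment = jnp.array([0.3, -0.5, 0.8])
target = dipole_B(true_moment, obs)

def loss_fn(moment):
    return jnp.mean((dipole_B(moment, obs) - target) ** 2)

grad_fn = jax.grad(loss_fn)

@jax.jit
def fit_moment(m0):
    # 100 plain gradient-descent steps (learning rate 0.5), traced into one XLA program.
    def step(_, m):
        return m - 0.5 * grad_fn(m)
    return jax.lax.fori_loop(0, 100, step, m0)

fitted = fit_moment(jnp.zeros(3))
print(fitted, loss_fn(fitted))
```

Keeping the loop inside `jit` avoids a host round-trip per step; for real magnet fits, an optimizer library and a loss scaled to the field magnitude are usually worth adding.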
Validation and release gates
CI enforces:
- lint and type checks,
- docs build,
- >=90% coverage,
- sharded `pytest -m 'not slow'` test runs,
- benchmark regression thresholds,
- profiling regression thresholds,
- Python compatibility checks on 3.10, 3.12, and 3.13.
Nightly workflows additionally run the full validation suite and extended profiling artifact generation.
Key repository files
- `PARITY_MATRIX.md`
- `MIGRATION_PLAN.md`
- `pyproject.toml`
- `benchmarks/thresholds.json`
- `profiling/thresholds.json`
- `profiling/hlo_baseline.json`
- `.readthedocs.yaml`
- `.github/workflows/publish-pypi.yml`
License
BSD-2-Clause.
File details
Details for the file magpylib_jax-1.0.1.tar.gz.
File metadata
- Download URL: magpylib_jax-1.0.1.tar.gz
- Upload date:
- Size: 110.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1134d42eaada0f630fee0e4c0dba6c366edb93a015ed517c0145cd6770778367 |
| MD5 | c262ab0ecac20f5e8d21e5e3050d763d |
| BLAKE2b-256 | dda6142f0d3466eb62da92aa5d634ae3936675c17a378a5fd4bd6abb4bafbc06 |
Provenance
The following attestation bundles were made for magpylib_jax-1.0.1.tar.gz:
- Publisher: publish-pypi.yml on uwplasma/magpylib_jax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: magpylib_jax-1.0.1.tar.gz
- Subject digest: 1134d42eaada0f630fee0e4c0dba6c366edb93a015ed517c0145cd6770778367
- Sigstore transparency entry: 1343318814
- Sigstore integration time:
- Permalink: uwplasma/magpylib_jax@12eb63658f5b2f6588d471492822fc69af449017
- Branch / Tag: refs/heads/main
- Owner: https://github.com/uwplasma
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-pypi.yml@12eb63658f5b2f6588d471492822fc69af449017
- Trigger Event: workflow_dispatch
File details
Details for the file magpylib_jax-1.0.1-py3-none-any.whl.
File metadata
- Download URL: magpylib_jax-1.0.1-py3-none-any.whl
- Upload date:
- Size: 70.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 64beb632ff43f0e282c726bfd9a3b85a4b70d73d14d2cefe794a604d48534393 |
| MD5 | eaacc2dd38516afc959b5e0dc72b19ed |
| BLAKE2b-256 | 98b92a4dae8ea9599f518d45ff68c816a0c3f638380fb9ca0362de6bbecead8d |
Provenance
The following attestation bundles were made for magpylib_jax-1.0.1-py3-none-any.whl:
- Publisher: publish-pypi.yml on uwplasma/magpylib_jax
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: magpylib_jax-1.0.1-py3-none-any.whl
- Subject digest: 64beb632ff43f0e282c726bfd9a3b85a4b70d73d14d2cefe794a604d48534393
- Sigstore transparency entry: 1343318825
- Sigstore integration time:
- Permalink: uwplasma/magpylib_jax@12eb63658f5b2f6588d471492822fc69af449017
- Branch / Tag: refs/heads/main
- Owner: https://github.com/uwplasma
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-pypi.yml@12eb63658f5b2f6588d471492822fc69af449017
- Trigger Event: workflow_dispatch