
FluxFEM: A weak-form-centric differentiable finite element framework in JAX


License: Apache-2.0

FluxFEM

A weak-form-centric differentiable finite element framework in JAX, where variational forms are treated as first-class, differentiable programs.

Examples and Features

[Figures: Example 1, diffusion (diffusion-mms); Example 2, Neo-Hookean hyperelasticity]

Features

  • Built on JAX, enabling automatic differentiation with grad, jit, vmap, and related transformations.
  • Weak-form-centric API that keeps the code close to the mathematical formulation: weak forms are represented as expression trees and compiled into element kernels, so residuals, tangents, and objectives can be differentiated automatically.
  • Two assembly approaches: tensor-based (scikit-fem–style) assembly and weak-form-based assembly.
  • Handles both linear and nonlinear analyses, with derivatives obtained by AD in JAX.
  • Optional PETSc/PETSc-shell solvers via petsc4py for scalable linear solves (install the fluxfem[petsc] extra).
  • Contact interface support for penalty/constraint contact formulations, including one-to-many contact spaces (OneToManyContactSurfaceSpace.from_meshes) and KKT assembly utilities (assemble_contact_coupling_matrices, assemble_contact_kkt, solve_contact_kkt).

Usage

This library provides two assembly approaches.

  • A tensor-based assembly, where trial and test functions are represented explicitly as element-level tensors and assembled accordingly (in the style of scikit-fem).
  • A weak-form-based assembly, where the variational form is written symbolically and compiled before assembly.

The two approaches are functionally equivalent and share the same element-level execution model; they differ only in how you author the weak form. The diffusion example below mirrors the diffusion case in the paper: the tensor-based version writes the element kernel directly with jnp, while the weak-form version builds a symbolic expression.

Assembly Flow

All expressions are first compiled into an element-level evaluation plan, which operates on quadrature-point–major tensors. This plan is then executed independently for each element during assembly.

As a result, both assembly approaches:

  • use the same quadrature-major (q, a, i) data layout,
  • perform element-local tensor contractions,
  • and are fully compatible with JAX transformations such as jit, vmap, and automatic differentiation.
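As a rough illustration of the quadrature-major idea (not FluxFEM's internal code), the NumPy sketch below stores shape-function gradients for a hypothetical 1D mesh of linear elements as an (element, q, a, i) array, contracts each element locally with einsum, and scatter-adds the element blocks into a global matrix. The leading element axis, the helper names, and the 1D mesh are assumptions made for this example.

```python
import numpy as np

def p1_bar_data(coords):
    """Quadrature-major data for a 1D mesh of linear (P1) elements.

    Returns the connectivity (n_elem, 2), gradN with layout
    (n_elem, q, a, i) = (element, quad point, local node, component),
    and wJ = quadrature weight * |det J| with shape (n_elem, q).
    """
    n_nodes = len(coords)
    conn = np.stack([np.arange(n_nodes - 1), np.arange(1, n_nodes)], axis=1)
    h = coords[conn[:, 1]] - coords[conn[:, 0]]        # element lengths
    w_ref = np.array([0.5, 0.5])                       # 2-point Gauss weights on [0, 1]
    g = np.stack([-1.0 / h, 1.0 / h], axis=1)          # P1 gradients are constant per element
    gradN = np.broadcast_to(g[:, None, :, None], (len(h), w_ref.size, 2, 1))
    wJ = w_ref[None, :] * h[:, None]
    return conn, gradN, wJ

def assemble_diffusion(coords, kappa=1.0):
    conn, gradN, wJ = p1_bar_data(coords)
    # Element-local contraction on quadrature-major tensors, summed over q.
    Ke = kappa * np.einsum("eqai,eqbi,eq->eab", gradN, gradN, wJ)
    K = np.zeros((len(coords), len(coords)))
    # Scatter-add every element block; np.add.at accumulates shared nodes.
    np.add.at(K, (conn[:, :, None], conn[:, None, :]), Ke)
    return K
```

With five equally spaced nodes on [0, 1] (h = 0.25) this gives the familiar (1/h)·tridiag(-1, 2, -1) pattern, with 1/h on the two boundary diagonal entries.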

kernel-based assembly (explicit JIT units)

If you want to control JIT boundaries explicitly, build a JIT-compiled element kernel and pass it to space.assemble. The kernel must return the integrated element contribution (the integrand already multiplied by the quadrature weights and detJ and summed over quadrature points), not the per-quadrature-point integrand. For untagged raw kernels, pass kind=.

import fluxfem as ff
import jax
import jax.numpy as jnp

# `mesh` is an existing hexahedral mesh (its construction is not shown here)
space = ff.make_hex_space(mesh, dim=1, intorder=2)

# bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)

# linear: kernel(ctx) -> (n_ldofs,)
def linear_kernel(ctx):
    integrand = ff.scalar_body_force_form(ctx, 2.0)
    wJ = ctx.w * ctx.test.detJ
    return (integrand * wJ[:, None]).sum(axis=0)

ker_F = jax.jit(linear_kernel)
F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)
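The difference between an integrand and an integrated contribution can be seen on a single hypothetical 1D linear element of length h under a constant body force f. The sketch below follows the same pattern as linear_kernel: the form-like function returns one row per quadrature point, and the kernel-like function weights those rows by w·detJ and sums over the quadrature axis. The function names and the 2-point Gauss rule are assumptions for illustration.

```python
import numpy as np

def p1_body_force_integrand(xi, f):
    """Per-quadrature-point integrand N_a(xi) * f, shape (n_qp, n_ldofs)."""
    N = np.stack([0.5 * (1.0 - xi), 0.5 * (1.0 + xi)], axis=1)
    return f * N

def linear_element_vector(h, f):
    xi = np.array([-1.0, 1.0]) / np.sqrt(3.0)     # 2-point Gauss points on [-1, 1]
    w = np.array([1.0, 1.0])                      # matching Gauss weights
    detJ = np.full_like(w, h / 2.0)               # map [-1, 1] onto an element of length h
    integrand = p1_body_force_integrand(xi, f)    # what a form returns: (n_qp, n_ldofs)
    wJ = w * detJ
    return (integrand * wJ[:, None]).sum(axis=0)  # what a kernel must return: (n_ldofs,)
```

For h = 0.5 and f = 2 the result is f·h/2 = 0.5 per node, the exact consistent load vector for a linear element.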

tensor-based vs weak-form-based (diffusion example)

tensor-based assembly

The tensor-based assembly provides an explicit, low-level formulation in which element kernels are written directly with jax.numpy (jnp).

import fluxfem as ff
import jax.numpy as jnp

@ff.kernel(kind="bilinear", domain="volume")
def diffusion_form(ctx: ff.FormContext, kappa):
    # ctx.test.gradN / ctx.trial.gradN: (n_qp, n_nodes, dim)
    # output tensor: (n_qp, n_nodes, n_nodes)
    return kappa * jnp.einsum("qia,qja->qij", ctx.test.gradN, ctx.trial.gradN)

space = ff.make_hex_space(mesh, dim=3, intorder=2)
params = ff.Params(kappa=1.0)
K_ts = space.assemble(diffusion_form, params=params.kappa)
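To make the einsum shapes in diffusion_form concrete, here is a NumPy sketch on a single linear triangle with vertices (0, 0), (1, 0), (0, 1). Its shape-function gradients are constant, so one quadrature point suffices. It evaluates the same "qia,qja->qij" contraction and then integrates with w·detJ; the triangle and its quadrature data are illustrative assumptions, not FluxFEM defaults.

```python
import numpy as np

def diffusion_integrand(gradN_test, gradN_trial, kappa):
    # Same contraction as diffusion_form: (q, i, a) x (q, j, a) -> (q, i, j)
    return kappa * np.einsum("qia,qja->qij", gradN_test, gradN_trial)

# One P1 triangle with vertices (0, 0), (1, 0), (0, 1): constant gradients.
G = np.array([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]])   # (n_nodes, dim)
gradN = G[None, :, :]                                  # add the quadrature axis: (1, 3, 2)
w, detJ = np.array([0.5]), np.array([1.0])             # 1-point rule; reference = physical

integrand = diffusion_integrand(gradN, gradN, kappa=1.0)   # (n_qp, 3, 3)
Ke = np.einsum("qij,q->ij", integrand, w * detJ)           # integrated element matrix
```

Each row of Ke sums to zero because constant functions have zero gradient, which is a quick sanity check for any diffusion kernel.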

weak-form-based assembly

In the weak-form-based assembly, the variational formulation itself is the primary object. The expression below defines a symbolic computation graph, which is later compiled and executed at the element level.

import fluxfem as ff
import fluxfem.helpers_wf as h_wf

space = ff.make_hex_space(mesh, dim=3, intorder=2)
params = ff.Params(kappa=1.0)

# u, v are symbolic trial/test fields (weak-form DSL objects).
# u.grad / v.grad are symbolic nodes (expression tree), not numeric arrays.
# dOmega() is the integral measure; the whole expression is compiled before assembly.
form_wf = ff.BilinearForm.volume(
    lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
).get_compiled()

K_wf = space.assemble(form_wf, params=params)
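The idea of a form as an expression tree that is compiled once into an element kernel can be sketched with a few hypothetical node classes. Grad, Dot, Param, and compile_node below are a toy written only for this illustration; FluxFEM's actual DSL objects (u.grad, v.grad, dOmega()) and its compiler are not shown. The compiled toy kernel returns the same (q, i, j) integrand as the einsum in the tensor-based version, which is why both approaches can share one execution model.

```python
import numpy as np
from dataclasses import dataclass

@dataclass(frozen=True)
class Grad:            # hypothetical node: gradient of "u" (trial) or "v" (test)
    field: str

@dataclass(frozen=True)
class Dot:             # hypothetical node: contraction over the spatial axis
    left: object
    right: object

@dataclass(frozen=True)
class Param:           # hypothetical node: named scalar coefficient times an expression
    name: str
    expr: object

def compile_node(node):
    """Recursively turn a tree node into a NumPy function of (gT, gU, params)."""
    if isinstance(node, Grad):
        if node.field == "v":
            return lambda gT, gU, p: gT[:, :, None, :]    # (q, i, 1, a)
        return lambda gT, gU, p: gU[:, None, :, :]        # (q, 1, j, a)
    if isinstance(node, Dot):
        fl, fr = compile_node(node.left), compile_node(node.right)
        return lambda gT, gU, p: (fl(gT, gU, p) * fr(gT, gU, p)).sum(axis=-1)
    if isinstance(node, Param):
        fe = compile_node(node.expr)
        return lambda gT, gU, p: p[node.name] * fe(gT, gU, p)
    raise TypeError(f"unknown node {node!r}")

# kappa * (grad v . grad u): built once as data, compiled once, evaluated per element.
tree = Param("kappa", Dot(Grad("v"), Grad("u")))
kernel = compile_node(tree)
```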

Linear Elasticity assembly (weak-form based assembly)

import fluxfem as ff
import fluxfem.helpers_wf as h_wf

space = ff.make_hex_space(mesh, dim=3, intorder=2)
D = ff.isotropic_3d_D(1.0, 0.3)

form_wf = ff.BilinearForm.volume(
    lambda u, v, D: h_wf.ddot(v.sym_grad, D @ u.sym_grad) * h_wf.dOmega()
).get_compiled()

K = space.assemble(form_wf, params=D)
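The exact layout returned by ff.isotropic_3d_D is not documented here. The sketch below assumes its two arguments are Young's modulus E and Poisson's ratio ν, and builds the standard 6×6 isotropic matrix in Voigt notation from the Lamé parameters λ = Eν/((1+ν)(1-2ν)) and μ = E/(2(1+ν)). The component ordering [xx, yy, zz, yz, xz, xy], the engineering-shear convention, and the function name are assumptions of this sketch.

```python
import numpy as np

def isotropic_3d_voigt(E, nu):
    """6x6 isotropic elasticity matrix in Voigt notation (hypothetical helper).

    Assumed ordering: [xx, yy, zz, yz, xz, xy] with engineering shear strains,
    so the shear diagonal entries are mu (not 2 * mu).
    """
    lam = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))   # first Lamé parameter
    mu = E / (2.0 * (1.0 + nu))                      # shear modulus
    D = np.zeros((6, 6))
    D[:3, :3] = lam                                  # volumetric coupling
    D[:3, :3] += 2.0 * mu * np.eye(3)                # normal stiffness lam + 2 mu
    D[3:, 3:] = mu * np.eye(3)                       # shear block
    return D
```

Inverting D recovers the uniaxial compliance 1/E and the lateral term -ν/E, which is a convenient check on the convention.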

Neo-Hookean residual assembly (weak-form DSL)

Below is a Neo-Hookean hyperelasticity example written in weak form. The residual is expressed symbolically and compiled into element-level kernels executed per element. No manual derivation of tangent operators is required; consistent tangents (Jacobians) for Newton-type solvers are obtained automatically via JAX AD.

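import fluxfem as ff
import fluxfem.helpers_wf as h_wf

def neo_hookean_residual_wf(v, u, params):
    mu = params["mu"]
    lam = params["lam"]
    F = h_wf.I(3) + h_wf.grad(u)  # deformation gradient F = I + grad(u)
    C = h_wf.matmul(h_wf.transpose(F), F)  # right Cauchy-Green tensor
    C_inv = h_wf.inv(C)
    J = h_wf.det(F)

    # second Piola-Kirchhoff stress of the compressible Neo-Hookean model
    S = mu * (h_wf.I(3) - C_inv) + lam * h_wf.log(J) * C_inv
    # variation of the Green-Lagrange strain: dE = sym(F^T grad(v))
    FtGv = h_wf.matmul(h_wf.transpose(F), h_wf.grad(v))
    dE = 0.5 * (FtGv + h_wf.transpose(FtGv))
    return h_wf.ddot(S, dE) * h_wf.dOmega()

res_form = ff.ResidualForm.volume(neo_hookean_residual_wf).get_compiled()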
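The claim that tangents come from AD can be checked outside the DSL with plain JAX. The sketch below assumes the standard compressible Neo-Hookean energy W(F) = μ/2 (tr C - 3) - μ ln J + λ/2 (ln J)², whose derivative 2∂W/∂C gives the same S as in neo_hookean_residual_wf. jax.grad yields the first Piola-Kirchhoff stress P = ∂W/∂F, and jax.jacfwd applied to that yields the 3×3×3×3 consistent tangent. None of this uses FluxFEM, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the checks

def neo_hookean_energy(F, mu, lam):
    """Strain energy W(F) = mu/2 (tr C - 3) - mu ln J + lam/2 (ln J)^2."""
    C = F.T @ F
    lnJ = jnp.log(jnp.linalg.det(F))
    return 0.5 * mu * (jnp.trace(C) - 3.0) - mu * lnJ + 0.5 * lam * lnJ**2

def second_pk_stress(F, mu, lam):
    """Closed-form S, the same expression as in the residual above."""
    C_inv = jnp.linalg.inv(F.T @ F)
    lnJ = jnp.log(jnp.linalg.det(F))
    return mu * (jnp.eye(3) - C_inv) + lam * lnJ * C_inv

# First Piola-Kirchhoff stress P = dW/dF, obtained by reverse-mode AD.
first_pk_stress = jax.grad(neo_hookean_energy, argnums=0)
# Consistent tangent dP/dF (shape 3x3x3x3) by differentiating P again.
tangent = jax.jacfwd(first_pk_stress, argnums=0)

mu, lam = 1.0, 2.0
F = jnp.eye(3) + 0.1 * jnp.array([[0.2, 0.1, 0.0],
                                  [0.0, -0.3, 0.05],
                                  [0.1, 0.0, 0.4]])
P = first_pk_stress(F, mu, lam)
S = second_pk_stress(F, mu, lam)
A = tangent(F, mu, lam)
```

Because P = F S, the virtual work satisfies P : ∇v = S : sym(Fᵀ∇v), so the strain variation that pairs with S is δE = sym(Fᵀ∇v). The tangent A also has major symmetry, as expected for a Hessian of W.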

autodiff + jit compile

You can differentiate through the solve and JIT compile the hot path. The inverse diffusion tutorial shows this pattern:

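import jax
import jax.numpy as jnp

# solve_u, traction_true, obs_idx_j and u_obs are defined in the tutorial.
solve_u_jit = jax.jit(solve_u)

def loss_theta(theta):
    kappa = jnp.exp(theta)  # log-parametrization keeps kappa > 0
    u = solve_u_jit(kappa, traction_true)
    diff = u[obs_idx_j] - u_obs[obs_idx_j]
    return 0.5 * jnp.mean(diff * diff)

loss_theta_jit = jax.jit(loss_theta)
grad_fn = jax.jit(jax.grad(loss_theta))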
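Since the tutorial's solve_u is not reproduced here, the self-contained sketch below substitutes a hypothetical 1D problem: linear elements on [0, 1], u(0) = 0, and a flux traction at x = 1, so the exact solution is u = traction·x/κ. It reuses the names from the snippet above (solve_u, traction_true, obs_idx_j, u_obs), but their definitions here are assumptions. jax.grad differentiates straight through jnp.linalg.solve.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N_EL = 8                 # hypothetical 1D mesh: 8 linear elements on [0, 1]
H = 1.0 / N_EL

def solve_u(kappa, traction):
    """Assemble and solve -(kappa u')' = 0, u(0) = 0, kappa u'(1) = traction."""
    ke = (kappa / H) * jnp.array([[1.0, -1.0], [-1.0, 1.0]])
    K = jnp.zeros((N_EL + 1, N_EL + 1))
    for e in range(N_EL):                        # unrolled once at trace time
        K = K.at[e:e + 2, e:e + 2].add(ke)
    F = jnp.zeros(N_EL + 1).at[-1].set(traction)
    u_free = jnp.linalg.solve(K[1:, 1:], F[1:])  # eliminate the Dirichlet node
    return jnp.concatenate([jnp.zeros(1), u_free])

solve_u_jit = jax.jit(solve_u)

traction_true = 1.0
obs_idx_j = jnp.array([2, 4, 8])                 # observed nodes
u_obs = solve_u(2.0, traction_true)              # synthetic data with kappa_true = 2

def loss_theta(theta):
    kappa = jnp.exp(theta)
    u = solve_u_jit(kappa, traction_true)
    diff = u[obs_idx_j] - u_obs[obs_idx_j]
    return 0.5 * jnp.mean(diff * diff)

loss_theta_jit = jax.jit(loss_theta)
grad_fn = jax.jit(jax.grad(loss_theta))
```

The gradient vanishes at θ = ln 2 because the synthetic data were generated with κ = 2, and at θ = 0 it matches a central finite difference.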

FESpace vs FESpacePytree

Use FESpace for standard workflows with a fixed mesh. When you need to carry the space through JAX transformations (e.g., shape optimization where mesh coordinates are part of the computation), use FESpacePytree via make_*_space_pytree(...). This keeps the mesh/basis in the pytree so jax.jit/jax.grad can see geometry changes.
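Why a pytree-backed space matters can be shown with a hypothetical stand-in: a NamedTuple (which JAX treats as a pytree) holding the node coordinates of a 1D bar. Because the coordinates are leaves, jax.grad of the compliance returns a gradient with respect to the geometry. BarSpace and compliance are illustrative names only; FluxFEM's FESpacePytree and make_*_space_pytree are not used here.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

class BarSpace(NamedTuple):
    """Hypothetical stand-in for a pytree-backed space: geometry is a leaf."""
    coords: jnp.ndarray   # (n_nodes,) node positions; differentiable leaf
    EA: float             # axial stiffness (also a leaf)

def compliance(space: BarSpace, load: float) -> jnp.ndarray:
    n = space.coords.shape[0]                      # static under jit
    lengths = jnp.diff(space.coords)               # element lengths from geometry
    K = jnp.zeros((n, n))
    for e in range(n - 1):
        k = space.EA / lengths[e]
        K = K.at[e:e + 2, e:e + 2].add(k * jnp.array([[1.0, -1.0], [-1.0, 1.0]]))
    f = jnp.zeros(n).at[-1].set(load)              # tip load at the last node
    u_free = jnp.linalg.solve(K[1:, 1:], f[1:])    # clamp node 0
    return f[1:] @ u_free                          # compliance f^T u

space = BarSpace(coords=jnp.linspace(0.0, 2.0, 5), EA=4.0)
g = jax.jit(jax.grad(compliance))(space, 3.0)     # g is a BarSpace of gradients
```

For a clamped bar under a tip load P the compliance is P²L/EA, so only the end coordinates carry sensitivity (-P²/EA at the first node, +P²/EA at the last), while moving interior nodes leaves the total length, and hence the compliance, unchanged.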

Mixed systems

Mixed problems can be assembled from residual blocks and solved as a coupled system.

import fluxfem as ff
import jax.numpy as jnp

mixed = ff.MixedFESpace({"u": space_u, "p": space_p})
residuals = ff.make_mixed_residuals(
    u=res_u,  # (v, u, params) -> Expr
    p=res_p,  # (q, u, params) -> Expr
)
problem = ff.MixedProblem(mixed, residuals, params=ff.Params(alpha=1.0))

u0 = jnp.zeros(mixed.n_dofs)
R = problem.assemble_residual(u0)
J = problem.assemble_jacobian(u0, return_flux_matrix=True)
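How a mixed problem can be solved monolithically is sketched below with plain JAX: two hypothetical residual blocks (a mildly nonlinear u block and a constraint p block) are packed into one vector, jax.jacfwd builds the full block Jacobian, and Newton's method solves the coupled system. The matrices and residuals are made-up algebraic stand-ins, not FluxFEM's MixedProblem internals.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical two-field problem: nonlinear "u" block plus a constraint "p" block.
K = jnp.array([[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]])
B = jnp.array([[1.0, 1.0, 1.0]])        # one constraint: sum(u) = g
f = jnp.array([1.0, 0.0, 1.0])
g = jnp.array([0.5])
N_U = 3                                  # number of u dofs; p dofs follow

def res_u(u, p):
    return K @ u + 0.1 * u**3 + B.T @ p - f

def res_p(u, p):
    return B @ u - g

def residual(x):
    u, p = x[:N_U], x[N_U:]              # split the packed vector into fields
    return jnp.concatenate([res_u(u, p), res_p(u, p)])

jacobian = jax.jacfwd(residual)          # [[J_uu, J_up], [J_pu, J_pp]] in one matrix

x = jnp.zeros(N_U + B.shape[0])
for _ in range(20):                      # Newton iterations
    dx = jnp.linalg.solve(jacobian(x), -residual(x))
    x = x + dx
    if jnp.linalg.norm(dx) < 1e-12:
        break
```

The zero p-p block of the Jacobian is the usual saddle-point structure of constraint-type mixed problems.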

Block assembly

For constrained problems such as contact with Lagrange multipliers, you can build the block matrix explicitly:

from fluxfem import solver as ff_solver

# Example blocks from contact coupling
K_uu = ...
K_cc = ...
K_uc = ...

blocks = ff_solver.make_block_matrix(
    diag=ff_solver.block_diag(order=("u", "c"), u=K_uu, c=K_cc),
    rel={("u", "c"): K_uc},
    symmetric=True,
    transpose_rule="T",
)

# Lazy container; assemble when you need the global matrix.
K = blocks.assemble()
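What a symmetric block assembly produces can be sketched with dense NumPy arrays. The hypothetical helper assemble_blocks below places diagonal blocks by field order and, when symmetric=True, fills the mirrored off-diagonal block with the transpose (the behavior transpose_rule="T" suggests). Unlike the lazy container in ff_solver, it assembles immediately, and the small K_uu, K_cc, K_uc values are invented for illustration.

```python
import numpy as np

def assemble_blocks(order, diag, rel, symmetric=True):
    """Hypothetical helper: build a dense global matrix from named blocks.

    diag maps a field name to its diagonal block; rel maps (row, col) names
    to an off-diagonal block. With symmetric=True the (col, row) block is
    filled with the transpose.
    """
    sizes = [diag[name].shape[0] for name in order]
    offsets = dict(zip(order, np.cumsum([0] + sizes[:-1])))
    n = sum(sizes)
    K = np.zeros((n, n))
    for name in order:
        o, m = offsets[name], diag[name].shape[0]
        K[o:o + m, o:o + m] = diag[name]
    for (r, c), blk in rel.items():
        ro, co = offsets[r], offsets[c]
        K[ro:ro + blk.shape[0], co:co + blk.shape[1]] = blk
        if symmetric:
            K[co:co + blk.shape[1], ro:ro + blk.shape[0]] = blk.T
    return K

K_uu = np.array([[2.0, -1.0], [-1.0, 2.0]])
K_cc = np.zeros((1, 1))
K_uc = np.array([[1.0], [-1.0]])        # one constraint: u0 - u1 = 0
K = assemble_blocks(("u", "c"), {"u": K_uu, "c": K_cc}, {("u", "c"): K_uc})
x = np.linalg.solve(K, np.array([1.0, 0.0, 0.0]))
```

Here the single multiplier row enforces u0 = u1, and the solution is u0 = u1 = λ = 0.5.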

FluxFEM also provides high-level contact utilities:

# Minimal one-to-many contact setup
contact = ff.OneToManyContactSurfaceSpace.from_meshes(
    master_mesh=mesh_master,
    slave_meshes=[mesh_slave],
    master_space=space_master,      # optional
    slave_spaces=[space_slave],     # optional
    master_facet_selector=select_master,
    slave_facet_selectors=[select_slave],
)

# 1) Assemble constraint operators (B, Kuu, ...)
lm_space = ff.ContactMultiplierSpace.from_contact(
    contact,
    family="p0",
    side="master",
)

ops: ff.ContactOperators = ff.assemble_contact_constraint_operators(
    contact,
    rho=1.0,
    multiplier=lm_space,
    backend="numpy",
    # Optional: also evaluate and store residual/jacobian metadata on the same ContactOperators.
    weak_form=contact_residual_form,
    state={"a": u_master, "b": u_slave},
    params=params,
)

# 2) Penalty-family path: user weak form -> residual/jacobian operators
ops_nitsche: ff.ContactOperators = ff.assemble_contact_penalty_operators(
    contact,
    weak_form=contact_residual_form,
    state={"a": u_master, "b": u_slave},
    params=params,
    backend="jax",
)

# 3) Unified coupled API (Penalty Family)
builder = ff.CoupledSystemBuilder.from_structural(K_u, F_u)
builder.register_blocks([
    ("master", space_master, {"value_dim": 1}),
    ("slave", space_slave, {"value_dim": 1}),
])
builder.add_contact(
    ops_nitsche,
    master="master",
    slave="slave",
    value_dim=1,
)
system = builder.build()
u = system.solve(dirichlet_dofs=dir_dofs, dirichlet_vals=0.0, format="csr")

# 4) Unified coupled API (Constraint Family): KKT assembly is internal to builder
builder_mortar = ff.CoupledSystemBuilder.from_structural(K_u, F_u)
builder_mortar.register_blocks([
    ("master", space_master, {"value_dim": 1}),
    ("slave", space_slave, {"value_dim": 1}),
])
builder_mortar.add_contact(
    ops,
    master="master",
    slave="slave",
    value_dim=1,
)
system_mortar = builder_mortar.build()

# Advanced: law/formulation can be set explicitly when needed.
# - law="one_sided_normal_frictionless"
# - formulation="multiplier" | "penalty_consistent"
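The difference between the two enforcement families can be seen on a single hypothetical degree of freedom: a spring of stiffness k loaded by f toward a rigid obstacle at u = gap. The constraint family solves a small KKT system, enforces the gap exactly, and returns the multiplier as the contact force; the penalty family adds stiffness ρ only while the gap is violated and allows a penetration of (f - k·gap)/(k + ρ). This is a plain penalty sketch, not FluxFEM's "penalty_consistent" formulation or its contact operators.

```python
import numpy as np

def solve_multiplier(k, f, gap):
    """Constraint family: enforce u <= gap exactly with a Lagrange multiplier."""
    u_free = f / k
    if u_free <= gap:                        # no contact: constraint inactive
        return u_free, 0.0
    A = np.array([[k, 1.0], [1.0, 0.0]])     # KKT matrix [[K, B^T], [B, 0]]
    u, lam = np.linalg.solve(A, np.array([f, gap]))
    return u, lam                            # lam is the contact force

def solve_penalty(k, f, gap, rho):
    """Penalty family: add stiffness rho only while the gap is violated."""
    u_free = f / k
    if u_free <= gap:
        return u_free
    return (f + rho * gap) / (k + rho)       # penetration shrinks as rho grows
```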

Contact API terminology (each term has a fixed meaning):

  • contact: interface geometry/pairing/supermesh/quadrature.
  • multiplier: LM discretization (family, side, value_dim).
  • formulation: enforcement variant used for routing.
  • ops: assembled bundle passed to CoupledSystemBuilder.

Notes:

  • Multiple contacts can be added to one builder, each with its own settings, by calling builder.add_contact(...) once per contact.
  • ContactMultiplierSpace(family="p0") currently supports side="master" only (implementation limitation).
  • See docs: Usage -> Contact API Boundaries.

Documentation

Full documentation, tutorials, and the API reference are hosted on the project's documentation site.

Tutorials

  • tutorials/linearelastic_tensile_bar.py (linear elasticity, weak-form assembly)
  • tutorials/neo_hookean_cantilever.py (nonlinear hyperelasticity)
  • tutorials/thermoelastic_bar_1d.py / tutorials/thermoelastic_bar_1d_mixed.py (thermoelastic coupling)
  • tutorials/contact_supported_box_by_pillars.py (large box supported by multiple small boxes via penalty contact + Dirichlet supports)
  • tutorials/petsc_shell_poisson_demo.py (PETSc shell solver integration; see also tutorials/petsc_shell_poisson_pmat_demo.py)

Setup

You can install FluxFEM either via pip or Poetry.

Supported Python Versions

FluxFEM supports Python 3.11–3.13.

Choose one of the following methods:

Using pip

pip install fluxfem
pip install "fluxfem[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Using poetry

poetry add fluxfem
poetry add "fluxfem[cuda12]"

PETSc Integration

Optional PETSc-based solvers are available via petsc4py. Enable with the extra:

Using pip

pip install "fluxfem[petsc]"
pip install "fluxfem[petsc,cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Using poetry (the last two commands are equivalent)

poetry add fluxfem --extras "petsc"
poetry add "fluxfem[petsc,cuda12]"
poetry add fluxfem --extras "petsc" --extras "cuda12"

Note: you must match the petsc4py version to the PETSc version you have installed. The current FluxFEM extra pins petsc4py==3.24.4 (see [project.optional-dependencies]), so make sure your PETSc install is compatible with that petsc4py release, or override it to match your PETSc build.

GPU note: this repo currently tests CUDA via the cuda12 extra only. Other CUDA versions are not covered by CI and may require manual JAX installation.

Acknowledgements

I acknowledge the open-source software, libraries, and communities that made this work possible.

Citation

Reference to cite if you use FluxFEM in a paper:

@software{Watanabe_FluxFEM_2026,
author = {Watanabe, Kohei},
doi = {10.5281/zenodo.18734689},
month = {2},
title = {{FluxFEM}},
url = {https://github.com/kevin-tofu/fluxfem},
year = {2026}
}

Download files


Source Distribution

fluxfem-0.2.8.tar.gz (186.3 kB, source)

Built Distribution


fluxfem-0.2.8-py3-none-any.whl (211.1 kB, Python 3)

File details

Details for the file fluxfem-0.2.8.tar.gz.

File metadata

  • Download URL: fluxfem-0.2.8.tar.gz
  • Size: 186.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.11.14 Linux/6.14.0-1017-azure

File hashes

Hashes for fluxfem-0.2.8.tar.gz
  • SHA256: fdf4681f5a816f24277e2a765011228db33a6b1f455865e3655725fecbdb5daf
  • MD5: f3ca2f38d34ce46d5f1fc680e6fffb64
  • BLAKE2b-256: 1a4cb22e5f7ee93c76ff995176181e1d8520dbf17a28e1ee271d3f0f2b199f1f


File details

Details for the file fluxfem-0.2.8-py3-none-any.whl.

File metadata

  • Download URL: fluxfem-0.2.8-py3-none-any.whl
  • Size: 211.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.11.14 Linux/6.14.0-1017-azure

File hashes

Hashes for fluxfem-0.2.8-py3-none-any.whl
  • SHA256: babcafb432a0fab33fa94d7c6eac04b7fcec3c8763ac1539a3104bf50eddf8bb
  • MD5: 6a19b2f99b30e537bba2cbe994d790ba
  • BLAKE2b-256: edad7f499ca9f26e7a5b5c043eec41ad4d02d10e0c3bafe51b9474c74cd9fc4a

