
FluxFEM: A weak-form-centric differentiable finite element framework in JAX


License: Apache-2.0

FluxFEM

A weak-form-centric differentiable finite element framework in JAX, where variational forms are treated as first-class, differentiable programs.

Examples and Features

Example 1: Diffusion
Example 2: Neo-Hookean hyperelasticity
[Figures: Diffusion-mms, Neo-Hookean]

Features

  • Built on JAX, enabling automatic differentiation with grad, jit, vmap, and related transformations.
  • Weak-form–centric API that keeps the code close to the mathematical formulation; weak forms are represented as expression trees and compiled into element kernels, enabling automatic differentiation of residuals, tangents, and objectives.
  • Two assembly approaches: tensor-based (scikit-fem–style) assembly and weak-form-based assembly.
  • Supports both linear and nonlinear analyses, with derivatives computed by JAX automatic differentiation (AD).

Usage

This library provides two assembly approaches.

  • A tensor-based assembly, where trial and test functions are represented explicitly as element-level tensors and assembled accordingly (in the style of scikit-fem).
  • A weak-form-based assembly, where the variational form is written symbolically and compiled before assembly.

The two approaches are functionally equivalent and share the same element-level execution model, but they differ in how you author the weak form. The diffusion example below mirrors the paper's diffusion case and contrasts the two: the tensor-based version is written directly with jax.numpy (jnp), while the weak-form-based version uses the symbolic DSL.

Assembly Flow

All expressions are first compiled into an element-level evaluation plan, which operates on quadrature-point–major tensors. This plan is then executed independently for each element during assembly.

As a result, both assembly approaches:

  • use the same quadrature-major (q, a, i) data layout,
  • perform element-local tensor contractions,
  • and are fully compatible with JAX transformations such as jit, vmap, and automatic differentiation.
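As a rough illustration of the quadrature-major layout and element-local contractions, the sketch below evaluates a diffusion-type element matrix with jnp.einsum and maps it over a batch of elements with jax.vmap under jax.jit. The array names and shapes (gradN as (n_qp, n_nodes, dim), quadrature weights w, Jacobian determinants detJ) are assumptions chosen to match the comments in the diffusion example, not FluxFEM internals.

```python
import jax
import jax.numpy as jnp


def element_matrix(gradN, w, detJ, kappa):
    """Integrated diffusion matrix for one element.

    gradN: (n_qp, n_nodes, dim) physical shape-function gradients
    w:     (n_qp,) quadrature weights
    detJ:  (n_qp,) Jacobian determinants
    """
    # Quadrature-major integrand: one (n_nodes, n_nodes) block per quadrature point.
    integrand = kappa * jnp.einsum("qia,qja->qij", gradN, gradN)
    # Weight each quadrature point and sum over q to get the element contribution.
    return jnp.einsum("qij,q->ij", integrand, w * detJ)


# Map the element-local kernel over a leading element axis, then compile it.
batched = jax.jit(jax.vmap(element_matrix, in_axes=(0, 0, 0, None)))

key = jax.random.PRNGKey(0)
n_elem, n_qp, n_nodes, dim = 4, 8, 8, 3
gradN = jax.random.normal(key, (n_elem, n_qp, n_nodes, dim))
w = jnp.ones((n_elem, n_qp))
detJ = jnp.full((n_elem, n_qp), 0.125)
Ke = batched(gradN, w, detJ, 1.0)
print(Ke.shape)  # (4, 8, 8)
```

Because each element is processed independently, the same function can be jitted, vmapped, or differentiated without changes.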

kernel-based assembly (explicit JIT units)

If you want to control JIT boundaries explicitly, build a JIT-compiled element kernel and pass it to space.assemble_*. The kernel must return the integrated element contribution (not the quadrature integrand).

import fluxfem as ff
import jax
import jax.numpy as jnp

# `mesh` is assumed to be a hexahedral mesh created beforehand.
space = ff.make_hex_space(mesh, dim=1, intorder=2)

# bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
K = space.assemble_bilinear_form(ff.diffusion_form, 1.0, kernel=ker_K)

# linear: kernel(ctx) -> (n_ldofs,)
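def linear_kernel(ctx):
    # quadrature-point integrand, shape (n_qp, n_ldofs)
    integrand = ff.scalar_body_force_form(ctx, 2.0)
    # quadrature weights times Jacobian determinant, shape (n_qp,)
    wJ = ctx.w * ctx.test.detJ
    # integrate: weight each quadrature point and sum over q
    return (integrand * wJ[:, None]).sum(axis=0)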

ker_F = jax.jit(linear_kernel)
F = space.assemble_linear_form(ff.scalar_body_force_form, 2.0, kernel=ker_F)
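The difference between the quadrature integrand and the integrated element contribution can be checked on a one-dimensional two-node element, where the exact load vector for a constant source f is f·h/2 per node. The sketch below mirrors the structure of linear_kernel (integrand times w·detJ, summed over quadrature points); the element, the 2-point Gauss rule, and the name linear_element_load are illustrative assumptions, not FluxFEM objects.

```python
import jax.numpy as jnp

# 2-point Gauss-Legendre rule on the reference interval [-1, 1].
xi = jnp.array([-1.0, 1.0]) / jnp.sqrt(3.0)
w = jnp.array([1.0, 1.0])


def linear_element_load(f, x0, x1):
    """Integrated load vector of a 1D linear element for a constant source f."""
    h = x1 - x0
    # Linear shape functions evaluated at the quadrature points: (n_qp, n_ldofs).
    N = jnp.stack([0.5 * (1.0 - xi), 0.5 * (1.0 + xi)], axis=1)
    detJ = jnp.full_like(xi, 0.5 * h)  # dx/dxi for an affine map
    integrand = f * N                  # quadrature integrand, not yet integrated
    wJ = w * detJ
    return (integrand * wJ[:, None]).sum(axis=0)  # shape (n_ldofs,)


Fe = linear_element_load(2.0, 0.0, 0.5)
print(Fe)  # approximately [0.5, 0.5]
```

Returning `integrand` instead of the weighted sum would give an array of shape (n_qp, n_ldofs), which is why the kernel must do the reduction itself.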

tensor-based vs weak-form-based (diffusion example)

tensor-based assembly

The tensor-based assembly provides an explicit, low-level formulation with element kernels written using jax.numpy (jnp).

import fluxfem as ff
import jax.numpy as jnp

def diffusion_form(ctx: ff.FormContext, kappa):
    # ctx.test.gradN / ctx.trial.gradN: (n_qp, n_nodes, dim)
    # output tensor: (n_qp, n_nodes, n_nodes)
    return kappa * jnp.einsum("qia,qja->qij", ctx.test.gradN, ctx.trial.gradN)

space = ff.make_hex_space(mesh, dim=3, intorder=2)
params = ff.Params(kappa=1.0)
K_ts = space.assemble_bilinear_form(diffusion_form, params=params.kappa)
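Conceptually, a bilinear assembly such as the one above evaluates the form per element, integrates it over quadrature points, and scatter-adds each element matrix into the global matrix. The sketch below does this for a uniform 1D mesh of two-node linear elements, using the same einsum contraction as diffusion_form and a dense scatter-add via `.at[].add`. The mesh, the dense storage, and the helper name assemble_1d_diffusion are assumptions for illustration (a real FE code would normally use sparse storage), and FluxFEM's own assembly path may differ.

```python
import jax.numpy as jnp


def assemble_1d_diffusion(n_elem, kappa, length=1.0):
    """Dense global stiffness for -(kappa u')' on [0, length] with linear elements."""
    h = length / n_elem
    conn = jnp.stack([jnp.arange(n_elem), jnp.arange(1, n_elem + 1)], axis=1)  # (n_elem, 2)

    # Gradients of linear shape functions are constant, so a single Gauss point
    # (weight 2 on [-1, 1]) is exact: gradN has shape (n_elem, n_qp=1, n_nodes=2, dim=1).
    gradN = jnp.broadcast_to(jnp.array([[-1.0], [1.0]]) / h, (n_elem, 1, 2, 1))
    wJ = jnp.full((n_elem, 1), 2.0 * 0.5 * h)  # weight 2 times detJ = h/2

    # Same contraction as diffusion_form, then integrate over quadrature points.
    integrand = kappa * jnp.einsum("eqia,eqja->eqij", gradN, gradN)
    Ke = jnp.einsum("eqij,eq->eij", integrand, wJ)  # (n_elem, 2, 2)

    # Scatter-add every element matrix into the global matrix; duplicates accumulate.
    rows = jnp.repeat(conn, 2, axis=1)  # global row index of each (i, j) entry
    cols = jnp.tile(conn, (1, 2))       # global column index of each (i, j) entry
    K = jnp.zeros((n_elem + 1, n_elem + 1))
    return K.at[rows.ravel(), cols.ravel()].add(Ke.reshape(-1))


K = assemble_1d_diffusion(3, 1.0)
```

For three elements of length 1/3 and kappa = 1, each element contributes kappa/h = 3 times [[1, -1], [-1, 1]], so K is 3 times the familiar tridiagonal pattern with 1 at the two boundary nodes and 2 at the interior ones.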

weak-form-based assembly

In the weak-form-based assembly, the variational formulation itself is the primary object. The expression below defines a symbolic computation graph, which is later compiled and executed at the element level.

import fluxfem as ff
import fluxfem.helpers_wf as h_wf

space = ff.make_hex_space(mesh, dim=3, intorder=2)
params = ff.Params(kappa=1.0)

# u, v are symbolic trial/test fields (weak-form DSL objects).
# u.grad / v.grad are symbolic nodes (expression tree), not numeric arrays.
# dOmega() is the integral measure; the whole expression is compiled before assembly.
form_wf = ff.BilinearForm.volume(
    lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
).get_compiled()

K_wf = space.assemble_bilinear_form(form_wf, params=params)
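To make the idea of a form as an expression tree concrete, here is a deliberately tiny, hypothetical DSL: the form kappa * (grad v · grad u) is built as a tree of dataclass nodes, and compile_form walks the tree once to produce an ordinary jitted JAX function that returns the integrated element matrix. None of these classes belong to FluxFEM; they only illustrate the build-then-compile pattern, and the toy supports nothing beyond parameter scaling and a test-trial gradient dot product.

```python
from dataclasses import dataclass
from typing import Union

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Grad:
    role: str  # "test" or "trial"


@dataclass(frozen=True)
class Dot:
    left: Grad   # expected to be the test gradient
    right: Grad  # expected to be the trial gradient


@dataclass(frozen=True)
class Scale:
    coeff: str   # name of a parameter looked up at evaluation time
    expr: "Node"


Node = Union[Grad, Dot, Scale]


def compile_form(expr: Node):
    """Walk the expression tree once and return a jitted element kernel."""

    def build(node):
        # Each builder maps (gradN, params) -> quadrature-major integrand (q, i, j).
        if isinstance(node, Dot):
            assert node.left.role == "test" and node.right.role == "trial"
            # Test and trial share one space here, so both use the same gradN.
            return lambda gradN, p: jnp.einsum("qia,qja->qij", gradN, gradN)
        if isinstance(node, Scale):
            inner = build(node.expr)
            return lambda gradN, p: p[node.coeff] * inner(gradN, p)
        raise TypeError(f"unsupported node: {node!r}")

    integrand_fn = build(expr)

    @jax.jit
    def kernel(gradN, wJ, params):
        # Integrate: weight each quadrature point and sum over q.
        return jnp.einsum("qij,q->ij", integrand_fn(gradN, params), wJ)

    return kernel


form = Scale("kappa", Dot(Grad("test"), Grad("trial")))  # kappa * (grad v . grad u)
kernel = compile_form(form)

# One 1D linear element of length 0.5 with a single Gauss point (w = 2, detJ = 0.25).
gradN = jnp.array([[[-2.0], [2.0]]])  # (n_qp=1, n_nodes=2, dim=1)
wJ = jnp.array([0.5])
K = kernel(gradN, wJ, {"kappa": 1.0})
```

The tree is inspected only once, at compile time; afterwards `kernel` is a plain JAX function, so it can be jitted, vmapped over elements, or differentiated with respect to `params`.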

Linear elasticity assembly (weak-form-based assembly)

import fluxfem as ff
import fluxfem.helpers_wf as h_wf

space = ff.make_hex_space(mesh, dim=3, intorder=2)
D = ff.isotropic_3d_D(1.0, 0.3)

form_wf = ff.BilinearForm.volume(
    lambda u, v, D: h_wf.ddot(v.sym_grad, D @ u.sym_grad) * h_wf.dOmega()
).get_compiled()

K = space.assemble_bilinear_form(form_wf, params=D)
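For reference, the standard isotropic constitutive matrix in Voigt notation can be written in a few lines. This sketch assumes that ff.isotropic_3d_D(1.0, 0.3) takes Young's modulus and Poisson's ratio and that strains use engineering shear components; neither assumption is confirmed here, so compare against the library before relying on it. For an isotropic material the shear block is mu times the identity, so the order of the three shear components does not change the matrix.

```python
import jax.numpy as jnp


def isotropic_D(E, nu):
    """6x6 isotropic elasticity matrix in Voigt notation (engineering shear strains)."""
    lam = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))  # first Lame parameter
    mu = E / (2.0 * (1.0 + nu))                      # shear modulus
    normal = lam * jnp.ones((3, 3)) + 2.0 * mu * jnp.eye(3)
    shear = mu * jnp.eye(3)
    zeros = jnp.zeros((3, 3))
    return jnp.block([[normal, zeros], [zeros, shear]])


D = isotropic_D(1.0, 0.3)
```

The same Lame parameters lam and mu are the ones that appear in the Neo-Hookean example below.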

Neo-Hookean residual assembly (weak-form DSL)

Below is a Neo-Hookean hyperelasticity example written in weak form. The residual is expressed symbolically and compiled into element-level kernels executed per element. No manual derivation of tangent operators is required; consistent tangents (Jacobians) for Newton-type solvers are obtained automatically via JAX AD.

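import fluxfem.helpers_wf as h_wf

def neo_hookean_residual_wf(v, u, params):
    mu = params["mu"]
    lam = params["lam"]
    F = h_wf.I(3) + h_wf.grad(u)  # deformation gradient F = I + grad(u)
    C = h_wf.matmul(h_wf.transpose(F), F)  # right Cauchy-Green tensor
    C_inv = h_wf.inv(C)
    J = h_wf.det(F)  # must stay positive for log(J)

    # second Piola-Kirchhoff stress of a compressible Neo-Hookean material
    S = mu * (h_wf.I(3) - C_inv) + lam * h_wf.log(J) * C_inv
    # variation of the Green-Lagrange strain: dE = sym(F^T grad(v))
    FtGv = h_wf.matmul(h_wf.transpose(F), h_wf.grad(v))
    dE = 0.5 * (FtGv + h_wf.transpose(FtGv))
    return h_wf.ddot(S, dE) * h_wf.dOmega()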
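The claim that tangents come from AD can be checked at a single material point with plain JAX, independent of the DSL. The sketch below assumes the compressible Neo-Hookean energy W(F) = mu/2 (tr C - 3) - mu ln J + lam/2 (ln J)^2, whose second Piola-Kirchhoff stress is the S used in the residual; jax.grad recovers the first Piola-Kirchhoff stress P = F S, and jax.jacfwd of that gives the consistent tangent dP/dF. It also checks that S : dE equals P : grad(v) when dE = sym(F^T grad(v)). The function names are illustrative, not FluxFEM API.

```python
import jax
import jax.numpy as jnp


def strain_energy(F, mu, lam):
    """Compressible Neo-Hookean strain energy density W(F)."""
    C = F.T @ F
    logJ = jnp.log(jnp.linalg.det(F))
    return 0.5 * mu * (jnp.trace(C) - 3.0) - mu * logJ + 0.5 * lam * logJ**2


def pk2_stress(F, mu, lam):
    """Second Piola-Kirchhoff stress, same expression as S in the residual."""
    C_inv = jnp.linalg.inv(F.T @ F)
    logJ = jnp.log(jnp.linalg.det(F))
    return mu * (jnp.eye(3) - C_inv) + lam * logJ * C_inv


# First Piola-Kirchhoff stress and its consistent tangent, both obtained by AD.
pk1_stress = jax.grad(strain_energy)  # P = dW/dF, shape (3, 3)
tangent = jax.jacfwd(pk1_stress)      # dP/dF, shape (3, 3, 3, 3)

mu, lam = 1.0, 2.0
F = jnp.eye(3) + 0.1 * jnp.array([[0.3, 0.1, 0.0], [0.0, 0.2, 0.1], [0.1, 0.0, -0.1]])
grad_v = jnp.array([[0.5, -0.2, 0.1], [0.0, 0.3, 0.4], [-0.1, 0.2, 0.6]])  # arbitrary test gradient

P = pk1_stress(F, mu, lam)
S = pk2_stress(F, mu, lam)
dE = 0.5 * (F.T @ grad_v + (F.T @ grad_v).T)  # variation of the Green-Lagrange strain
A = tangent(F, mu, lam)
```

In a Newton solver the same idea applies at the element level: differentiating the assembled residual with respect to the nodal displacements yields the stiffness matrix without a hand-derived tangent.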

autodiff + jit compile

You can differentiate through the solve and JIT-compile the hot path. The inverse diffusion tutorial shows this pattern:

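import jax
import jax.numpy as jnp

# solve_u, traction_true, obs_idx_j and u_obs come from the inverse diffusion tutorial.
solve_u_jit = jax.jit(solve_u)

def loss_theta(theta):
    kappa = jnp.exp(theta)  # log-parametrization keeps kappa positive
    u = solve_u_jit(kappa, traction_true)
    diff = u[obs_idx_j] - u_obs[obs_idx_j]
    return 0.5 * jnp.mean(diff * diff)

loss_theta_jit = jax.jit(loss_theta)
grad_fn = jax.jit(jax.grad(loss_theta))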
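A self-contained version of the same pattern, with a hypothetical 1D problem standing in for the tutorial's solve_u: a bar on [0, 1] with conductivity kappa = exp(theta), u(0) = 0 and a prescribed flux at x = 1, discretized with linear elements and solved densely with jnp.linalg.solve. For this problem the nodal solution is exactly u = t x / kappa, so the gradient returned by jax.grad can be compared with the analytic value. This sketch does not use FluxFEM; it only shows that JAX differentiates through the linear solve.

```python
import jax
import jax.numpy as jnp

n_elem = 8
h = 1.0 / n_elem
x = jnp.linspace(0.0, 1.0, n_elem + 1)[1:]  # free-node coordinates (u = 0 at x = 0 is eliminated)


def solve_u(kappa, traction):
    """Linear-element solve of -(kappa u')' = 0, u(0) = 0, kappa u'(1) = traction."""
    n = n_elem
    main = jnp.full(n, 2.0).at[-1].set(1.0)  # last node belongs to a single element
    off = jnp.ones(n - 1)
    K = (kappa / h) * (jnp.diag(main) - jnp.diag(off, 1) - jnp.diag(off, -1))
    f = jnp.zeros(n).at[-1].set(traction)    # flux enters at the right end
    return jnp.linalg.solve(K, f)


traction_true = 1.0
u_obs = solve_u(2.0, traction_true)           # synthetic observations with kappa_true = 2
obs_idx = jnp.array([1, 3, 5, 7])

solve_u_jit = jax.jit(solve_u)


def loss_theta(theta):
    kappa = jnp.exp(theta)
    u = solve_u_jit(kappa, traction_true)
    diff = u[obs_idx] - u_obs[obs_idx]
    return 0.5 * jnp.mean(diff * diff)


grad_fn = jax.jit(jax.grad(loss_theta))
```

Since du/dtheta = -u here, the exact gradient is -mean((u - u_obs) u) over the observed nodes; at theta = 0 this is -0.5 · mean(x²) = -0.234375, and it vanishes at theta = log 2.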

Documentation

Full documentation, tutorials, and API reference are hosted at this site.

Setup

You can install FluxFEM either via pip or Poetry.

Supported Python Versions

FluxFEM supports Python 3.11–3.13.

Choose one of the following methods:

Using pip

pip install fluxfem

Using Poetry

poetry add fluxfem

Acknowledgements

I acknowledge the open-source software, libraries, and communities that made this work possible.


Download files


Source Distribution

fluxfem-0.1.7.tar.gz (126.1 kB view details)

Uploaded Source

Built Distribution


fluxfem-0.1.7-py3-none-any.whl (143.2 kB view details)

Uploaded Python 3

File details

Details for the file fluxfem-0.1.7.tar.gz.

File metadata

  • Download URL: fluxfem-0.1.7.tar.gz
  • Upload date:
  • Size: 126.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.11.14 Linux/6.11.0-1018-azure

File hashes

Hashes for fluxfem-0.1.7.tar.gz
Algorithm Hash digest
SHA256 edd3fe0077e53d9355824965761a5153adaa6ad2201733b2dffc2f38c1ed5199
MD5 d1937b6ab68891a409da22516f2b8f54
BLAKE2b-256 63df8e0e641469cfce67f7f444616375352a9ef18ebb3a6f4fc352a083439dd3


File details

Details for the file fluxfem-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: fluxfem-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 143.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.11.14 Linux/6.11.0-1018-azure

File hashes

Hashes for fluxfem-0.1.7-py3-none-any.whl
Algorithm Hash digest
SHA256 957498c74b3e6a7506ea38e2799b0beefebccdbfbc111c7386a5053e72b72e1c
MD5 2b98d70c9a60a7ac66dca7c63595da08
BLAKE2b-256 17aea9f11d10ef272344a9c5d7673f84caa0367c6a8d1780b463a9cfb9c664f1
