$\text{{Py}{\color{#EE4C2C}T}{orch }{\color{#EE4C2C}H}{igh }{\color{#EE4C2C}O}{rder }{\color{#EE4C2C}A}{utomatic }{\color{#EE4C2C}D}{ifferentiation}}$
[!NOTE] This package is still in an experimental stage. It may exhibit unstable behavior or produce unexpected results, and is subject to possible minor structural modifications in the future.
Introduction
thoad is a lightweight reverse-mode automatic differentiation engine written entirely in Python that works over PyTorch’s computational graph to compute high-order partial derivatives. Unlike PyTorch’s native autograd, which is limited to first-order partial derivatives, thoad can efficiently propagate arbitrary-order derivatives throughout the graph, enabling more advanced gradient-based computations.
Key Features
- Python 3.12+: the package is developed using Python version 3.12 or higher.
- Built on PyTorch: the package uses PyTorch as its only dependency. It supports more than 70 PyTorch operator backward functions.
- High-Order Differentiation: the package can compute arbitrary-order partial derivatives - including cross derivatives.
- Adoption of the PyTorch Computational Graph: the package autodifferentiates over the graph generated by PyTorch controllers.
- Non-Scalar Differentiation: unlike torch.Tensor.backward, the package allows starting differentiation from non-scalar tensors.
- Support for Backward Hooks: the package allows registering backward hooks for dynamic tuning of propagated high-order gradients.
- Diagonal Optimization: the package detects and avoids duplication of cross diagonal dimensions during back-propagation.
- Symmetry Optimization: the package avoids computing derivative blocks that are redundant under Schwarz's theorem (symmetry of mixed partial derivatives).
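To give a sense of how much the symmetry optimization can save, the sketch below counts derivative blocks with and without Schwarz symmetry. It is plain Python and does not call thoad; the function name `derivative_blocks` and the string labels standing in for leaf tensors are assumptions made for illustration only.

```python
from itertools import combinations_with_replacement, product
from typing import List, Sequence, Tuple


def derivative_blocks(
    leaves: Sequence[str], order: int
) -> Tuple[List[Tuple[str, ...]], List[Tuple[str, ...]]]:
    """Enumerate derivative blocks of a given order over labelled leaves.

    Returns (all ordered blocks, blocks unique up to reordering).
    """
    # Every ordered choice of `order` leaves, e.g. (X, Y) and (Y, X).
    all_blocks = list(product(leaves, repeat=order))
    # Schwarz's theorem makes reorderings equivalent, so one
    # representative per multiset of leaves is enough.
    unique_blocks = list(combinations_with_replacement(leaves, order))
    return all_blocks, unique_blocks


all_blocks, unique_blocks = derivative_blocks(["X", "Y"], order=3)
print(len(all_blocks), len(unique_blocks))
```

With two leaves and third-order derivatives, only the blocks XXX, XXY, XYY and YYY carry distinct information; the remaining ordered blocks are permutations of these.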
Using the Package
thoad exposes two primary interfaces for computing high-order derivatives:
- thoad.backward: a function-based interface that closely resembles torch.Tensor.backward. It provides a quick way to compute high-order gradients without needing to manage an explicit controller object, but it offers only the core functionality (derivative computation and storage).
- thoad.Controller: a class-based interface that wraps the output tensor’s subgraph in a controller object. In addition to performing the same high-order backward pass, it gives access to advanced features such as fetching specific mixed partials, inspecting batch-dimension optimizations, overriding backward-function implementations, retaining intermediate partials, and registering custom hooks.
thoad.backward
The thoad.backward function computes high-order partial derivatives of a given output tensor and stores them in each leaf tensor’s .hgrad attribute.
Arguments:
- tensor: A PyTorch tensor from which to start the backward pass. This tensor must require gradients and be part of a differentiable graph.
- order: A positive integer specifying the maximum order of derivatives to compute.
- gradient: A tensor with the same shape as tensor to seed the vector-Jacobian product (i.e., custom upstream gradient). If omitted, the default is used.
- crossings: A boolean flag (default=False). If set to True, mixed partial derivatives (i.e., derivatives that involve more than one distinct leaf tensor) will be computed.
- groups: An iterable of disjoint groups of leaf tensors. When crossings=False, only those mixed partials whose participating leaf tensors all lie within a single group will be calculated. If crossings=True and groups is provided, a ValueError will be raised (they are mutually exclusive).
- keep_batch: A boolean flag (default=False) that controls how output dimensions are organized in the computed gradients.
  - When keep_batch=False: Gradients are returned in a fully flattened form. Concretely, think of the gradient tensor as having:
    - A single “output” axis that lists every element of the original output tensor (flattened into one dimension).
    - One axis per derivative order, each listing every element of the corresponding input (also flattened).

    For an N-th order derivative of a leaf tensor with input_numel elements and an output with output_numel elements, the gradient shape is:
    - Axis 1: indexes all output_numel outputs
    - Axes 2…(N+1): each indexes all input_numel inputs
  - When keep_batch=True: Gradients preserve both a flattened “output” axis and each original output dimension before any input axes. You can visualize it as:
    - Axis 1 flattens all elements of the output tensor (size = output_numel).
    - Axes 2...(k+1) correspond exactly to each dimension of the output tensor (if the output was shape (d1, d2, ..., dk), these axes have sizes d1, d2, ..., dk).
    - Axes (k+2)...(k+N+1) each flatten all input_numel elements of the leaf tensor, one axis per derivative order.

    However, if a particular output axis does not influence the gradient for a given leaf, that axis is not expanded and instead becomes a size-1 dimension. This means only those output dimensions that actually affect a particular leaf’s gradient “spread” into the input axes; any untouched axes remain at size 1, saving memory.
- keep_schwarz: A boolean flag (default=False). If True, symmetric (Schwarz) permutations are retained explicitly instead of being canonicalized/reduced, which is useful for debugging or inspecting non-reduced layouts.
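The layout rules for keep_batch can be sketched as two small shape calculators. This is a literal reading of the description above in plain Python, not thoad code: the helper names `flat_hgrad_shape` and `batched_hgrad_shape` and the `influencing_axes` parameter are assumptions introduced here for illustration.

```python
from math import prod
from typing import Optional, Sequence, Set, Tuple


def flat_hgrad_shape(
    output_shape: Sequence[int], input_shape: Sequence[int], order: int
) -> Tuple[int, ...]:
    """Shape described for keep_batch=False: one flattened output axis,
    then one flattened input axis per derivative order."""
    return (prod(output_shape),) + (prod(input_shape),) * order


def batched_hgrad_shape(
    output_shape: Sequence[int],
    input_shape: Sequence[int],
    order: int,
    influencing_axes: Optional[Set[int]] = None,
) -> Tuple[int, ...]:
    """Shape described for keep_batch=True: flattened output axis, then the
    output's own dimensions, then one flattened input axis per order.

    `influencing_axes` (hypothetical) lists the output axes that affect this
    leaf's gradient; all other output axes collapse to size 1.
    """
    if influencing_axes is None:
        influencing_axes = set(range(len(output_shape)))
    batch = tuple(
        dim if axis in influencing_axes else 1
        for axis, dim in enumerate(output_shape)
    )
    return (prod(output_shape),) + batch + (prod(input_shape),) * order


print(flat_hgrad_shape((10, 15), (10, 15), order=2))
print(batched_hgrad_shape((10, 15), (10, 15), order=1, influencing_axes={0}))
```

Note that the usage example in this document asserts shapes that use the leaf's own dimensions in place of each flattened input axis, e.g. (150, 10, 15, 10, 15) for a second-order derivative. The element count is the same as in the flattened view, so the two differ only by a reshape.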
Returns:
- An instance of thoad.Controller wrapping the same tensor and graph.
Executing Autodifferentiation
import torch
import thoad
from torch.nn import functional as F
### Normal PyTorch workflow
X = torch.rand(size=(10,15), requires_grad=True)
Y = torch.rand(size=(15,20), requires_grad=True)
Z = F.scaled_dot_product_attention(query=X, key=Y.T, value=Y.T)
### Call thoad backward
order = 2
thoad.backward(tensor=Z, order=order)
### Checks
# check derivative shapes
for o in range(1, 1 + order):
    assert X.hgrad[o - 1].shape == (Z.numel(), *(o * tuple(X.shape)))
    assert Y.hgrad[o - 1].shape == (Z.numel(), *(o * tuple(Y.shape)))
# check first derivatives (jacobians)
fn = lambda x, y: F.scaled_dot_product_attention(x, y.T, y.T)
J = torch.autograd.functional.jacobian(fn, (X, Y))
assert torch.allclose(J[0].flatten(), X.hgrad[0].flatten(), atol=1e-6)
assert torch.allclose(J[1].flatten(), Y.hgrad[0].flatten(), atol=1e-6)
# check second derivatives (hessians)
fn = lambda x, y: F.scaled_dot_product_attention(x, y.T, y.T).sum()
H = torch.autograd.functional.hessian(fn, (X, Y))
assert torch.allclose(H[0][0].flatten(), X.hgrad[1].sum(0).flatten(), atol=1e-6)
assert torch.allclose(H[1][1].flatten(), Y.hgrad[1].sum(0).flatten(), atol=1e-6)
thoad.Controller
The Controller class wraps a tensor’s backward subgraph in a controller object, performing the same core high-order backward pass as thoad.backward while exposing advanced customization, inspection, and override capabilities.
Instantiation
Use the constructor to create a controller for any tensor requiring gradients:
controller = thoad.Controller(tensor=GO) # takes graph output tensor
- tensor: A PyTorch Tensor with requires_grad=True and a non-None grad_fn.
Properties
- .tensor → Tensor: The output tensor underlying this controller. Setter: Replaces the tensor (after validation), rebuilds the internal computation graph, and invalidates any previously computed gradients.
- .compatible → bool: Indicates whether every backward function in the tensor’s subgraph has a supported high-order implementation. If False, some derivatives may fall back or be unavailable.
- .index → Dict[Type[torch.autograd.Function], Type[ExtendedAutogradFunction]]: A mapping from base PyTorch autograd.Function classes to thoad’s ExtendedAutogradFunction implementations. Setter: Validates and injects your custom high-order extensions.
Core Methods
.backward(order, gradient=None, crossings=False, groups=None, keep_batch=False, keep_schwarz=False) → None
Performs the high-order backward pass up to the specified derivative order, storing all computed partials in each leaf tensor’s .hgrad attribute.
- order (int > 0): maximum derivative order.
- gradient (Optional[Tensor]): custom upstream gradient with the same shape as controller.tensor.
- crossings (bool, default False): If True, mixed partial derivatives across different leaf tensors will be computed.
- groups (Optional[Iterable[Iterable[Tensor]]], default None): When crossings=False, restricts mixed partials to those whose leaf tensors all lie within a single group. If crossings=True and groups is provided, a ValueError is raised.
- keep_batch (bool, default False): controls whether independent output axes are kept separate (batched) or merged (flattened) in stored/retrieved gradients.
- keep_schwarz (bool, default False): if True, retains symmetric permutations explicitly (no Schwarz reduction).
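The interaction between crossings and groups can be illustrated by enumerating which combinations of leaves would be requested. The sketch below is plain Python and does not call thoad. The function name `requested_partials` is hypothetical and strings stand in for leaf tensors. It also assumes that pure partials of every leaf are always computed, and it lists each combination once, up to Schwarz symmetry.

```python
from itertools import combinations_with_replacement
from typing import Iterable, List, Optional, Sequence, Tuple


def requested_partials(
    leaves: Sequence[str],
    order: int,
    crossings: bool = False,
    groups: Optional[Iterable[Iterable[str]]] = None,
) -> List[Tuple[str, ...]]:
    """List leaf combinations whose partials are requested, up to `order`."""
    if crossings and groups is not None:
        # Mirrors the documented rule: the two options are mutually exclusive.
        raise ValueError("crossings=True and groups are mutually exclusive")
    if crossings:
        # Any mix of leaves is allowed.
        pools = [tuple(leaves)]
    else:
        # Pure partials of each leaf, plus mixes inside each group.
        pools = [(leaf,) for leaf in leaves]
        pools += [tuple(group) for group in (groups or [])]
    found = set()
    for pool in pools:
        for k in range(1, order + 1):
            found.update(combinations_with_replacement(sorted(pool), k))
    return sorted(found, key=lambda combo: (len(combo), combo))


print(requested_partials(["X", "Y", "Z"], order=2))
print(requested_partials(["X", "Y", "Z"], order=2, groups=[["X", "Y"]]))
```

With groups=[["X", "Y"]], the mixed partial (X, Y) is added to the pure partials, while no mixed partial involving Z is requested.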
.display_graph() → None
Prints a tree representation of the tensor’s backward subgraph. Supported nodes are shown normally; unsupported ones are annotated with (not supported).
.register_backward_hook(variables: Sequence[Tensor], hook: Callable) → None
Registers a user-provided hook to run during the backward pass whenever gradients for any of the specified leaf variables are computed.
- variables (Sequence[Tensor]): Leaf tensors to monitor.
- hook (Callable[[Tuple[Tensor, Tuple[Shape, ...], Tuple[Indep, ...]], dict[AutogradFunction, set[Tensor]]], Tuple[Tensor, Tuple[Shape, ...], Tuple[Indep, ...]]]): Receives the current (Tensor, shapes, indeps) plus contextual info, and must return the modified triple.
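As a rough sketch of what such a hook might look like, the factory below builds a hook that rescales the propagated values and passes the metadata through unchanged. It assumes the hook is called with two positional arguments, the (tensor, shapes, indeps) triple and the context dictionary, as the type signature above suggests. The names `make_scaling_hook` and `GradData` are hypothetical, and the block itself does not import thoad or torch.

```python
from typing import Any, Callable, Dict, Tuple

# Stand-in alias: in thoad the first entry would be a torch.Tensor.
GradData = Tuple[Any, Tuple[Any, ...], Tuple[Any, ...]]


def make_scaling_hook(
    factor: float,
) -> Callable[[GradData, Dict[Any, Any]], GradData]:
    """Build a hook that multiplies the propagated gradient by `factor`."""

    def hook(grad_data: GradData, context: Dict[Any, Any]) -> GradData:
        tensor, shapes, indeps = grad_data
        # Only the values change, so shapes and indeps still describe
        # the returned tensor and are passed through untouched.
        return (tensor * factor, shapes, indeps)

    return hook


halve = make_scaling_hook(0.5)
# Hypothetical registration, following the signature documented above:
# controller.register_backward_hook(variables=(X,), hook=halve)
```

A hook that changes the tensor's layout would presumably also need to return matching shapes and indeps, since the contract is to return the full modified triple.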
.require_grad_(variables: Sequence[Tensor]) → None
Marks the given leaf variables so that all intermediate partials involving them are retained, even if not required for the final requested gradients. Useful for inspecting or re-using higher-order intermediates.
.fetch_hgrad(variables: Sequence[Tensor], keep_batch: bool = False, keep_schwarz: bool = False) → Tuple[Tensor, Tuple[Tuple[Shape, ...], Tuple[Indep, ...], VPerm]]
Retrieves the precomputed high-order partial corresponding to the ordered sequence of leaf variables.
- variables (Sequence[Tensor]): the leaf tensors whose mixed partial you want.
- keep_batch (bool, default False): if True, each independent output axis remains a separate batch dimension in the returned tensor; if False, independent axes are distributed/merged into derivative dimensions.
- keep_schwarz (bool, default False): if True, returns derivatives retaining symmetric permutations explicitly.
Returns a pair:
- Gradient tensor: the computed partial derivatives, shaped according to output and input dimensions (respecting keep_batch/keep_schwarz).
- Metadata tuple
  - Shapes (Tuple[Shape, ...]): the original shape of each leaf tensor.
  - Indeps (Tuple[Indep, ...]): for each variable, indicates which output axes remained independent (batch) vs. which were merged into derivative axes.
  - VPerm (Tuple[int, ...]): a permutation that maps the internal derivative layout to the requested variables order.
Use the combination of independent-dimension info and shapes to reshape or interpret the returned gradient tensor in your workflow.
Executing Autodifferentiation
import torch
import thoad
from torch.nn import functional as F
### Normal PyTorch workflow
X = torch.rand(size=(10,15), requires_grad=True)
Y = torch.rand(size=(15,20), requires_grad=True)
Z = F.scaled_dot_product_attention(query=X, key=Y.T, value=Y.T)
### Instantiate thoad controller and call backward
order = 2
controller = thoad.Controller(tensor=Z)
controller.backward(order=order, crossings=True)
### Fetch Partial Derivatives
# fetch the 2nd order derivatives of X and Y
partial_XX, _ = controller.fetch_hgrad(variables=(X, X))
partial_YY, _ = controller.fetch_hgrad(variables=(Y, Y))
assert torch.allclose(partial_XX, X.hgrad[1])
assert torch.allclose(partial_YY, Y.hgrad[1])
# fetch cross derivatives
partial_XY, _ = controller.fetch_hgrad(variables=(X, Y))
partial_YX, _ = controller.fetch_hgrad(variables=(Y, X))
[!TIP] A more detailed user guide with examples and feature walkthroughs is available in the notebook: https://github.com/mntsx/thoad/blob/master/examples/user_guide.ipynb
More About the Package
Future Plans
The following outlines the planned future developments and improvements for thoad:
- Extend Backward Functionality
  Develop further backprop capabilities to improve PyTorch integration, supporting a broad subset of the most commonly used controllers.
- Advanced Optimization Framework
  Build an optimization module inspired by the design of torch.optim, with full support for higher-order gradients and flexible optimizer composition.
- PyTorch Integration
  Fully integrating the package into the PyTorch framework would be exciting, but it is unlikely to happen, since keeping the two coordinated and stable would require changes to PyTorch itself. Specifically:
  - Providing it with a more comprehensive tool for accessing controllers' contextual information.
  - Improving the accessibility of the type signatures of the backward functions.
License
This project is licensed under the MIT License.
See the LICENSE file for details.
PyTorch is distributed under the BSD 3-Clause License.
See PyTorch’s own LICENSE file for its full terms.
How to Cite
If you use thoad in your work, please consider citing it with the following BibTeX entry:
@Misc{thoad2025,
title = {thoad: PyTorch High Order Reverse-Mode Auto-Differentiation},
howpublished = {\url{https://github.com/mntsx/thoad}},
year = {2025}
}