tojax
Effortlessly transform PyTorch functions to JAX
tojax is a powerful library that enables seamless translation of pure PyTorch functions and models to JAX, combining PyTorch's familiar API with JAX's performance advantages including XLA compilation and automatic differentiation.
Key Features
- Automatic Model Translation: Convert PyTorch models to JAX with a single function call
- Function-Level Translation: Translate individual PyTorch operations to JAX equivalents
- Tensor Compatibility: Use PyTorch-style tensor operations backed by JAX arrays
- In-Place Operation Support: Handle PyTorch's mutable semantics in JAX's immutable world
- Specialized Library Support: Built-in patches for E3NN and FairChem models
- Graph Translation: Convert PyTorch FX computation graphs to JAX functions
Installation
Using uv

```shell
uv add tojax
```

Using pip

```shell
pip install tojax
```
Environment
In general, the prebuilt binaries of JAX and PyTorch do not work well within the same environment if both are installed with CUDA. To avoid this issue, please install only one of them with CUDA. Most likely, you want JAX to have the CUDA bindings since that is where computations are executed with this library.
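For example, assuming pip and a CUDA 12 machine, one way to set this up is to install JAX with CUDA support and a CPU-only PyTorch build (the index URL below is PyTorch's official CPU wheel index):

```shell
# JAX with CUDA bindings; computations run here
pip install "jax[cuda12]"
# CPU-only PyTorch, so the two CUDA runtimes do not conflict
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install tojax
```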
Examples
1. Function Translation
tojax automatically translates PyTorch functions to JAX equivalents:
```python
import jax.numpy as jnp
import torch
from tojax import tojax

# Get the JAX equivalent of a PyTorch function
jax_add = tojax(torch.add)

# Use it with JAX arrays
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
result = jax_add(a, b)  # Uses the JAX implementation
```
2. Model Conversion
Convert entire PyTorch models to JAX functions:
```python
import jax.numpy as jnp
import torch.nn as nn
from tojax import tojax

# Define a PyTorch model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

# Convert to JAX
model = SimpleModel()
jax_model = tojax(model)

# Use with JAX arrays
x = jnp.ones((32, 10))
output = jax_model(x)
```
3. JIT Compilation
Everything is JIT compatible:
```python
import time

import jax
import jax.numpy as jnp
import torch.nn as nn
from tojax import tojax

class SimpleModel(nn.Module):  # Same model as the earlier example
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

# Convert the model
model = SimpleModel()
jax_model = tojax(model)

# JIT compile for performance
@jax.jit
def fast_inference(x):
    return jax_model(x)

# Benchmark
x = jnp.ones((1000, 10))

# The first call compiles
start = time.time()
result = jax.block_until_ready(fast_inference(x))
compile_time = time.time() - start

# Subsequent calls are fast
start = time.time()
result = jax.block_until_ready(fast_inference(x))
runtime = time.time() - start

print(f"Compile time: {compile_time:.4f}s")
print(f"Runtime: {runtime:.6f}s")
```
4. Gradient Computation
You can use standard JAX transformations like jax.grad:
```python
import jax
import jax.numpy as jnp
import torch.nn as nn
from tojax import tojax

# Define a simple model
class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)

model = LinearModel()
jax_model = tojax(model)

# Define a loss function
def loss_fn(x, y):
    pred = jax_model(x)
    return jnp.mean((pred - y) ** 2)

# Compute gradients with respect to the input x
x = jnp.zeros((100, 2))
y = jnp.zeros((100, 1))
grad_fn = jax.grad(loss_fn)
gradients = grad_fn(x, y)
print(f"Gradient shape: {gradients.shape}")  # (100, 2), same shape as x
```
5. Export
Importantly, the resulting JAX functions can be exported and loaded without having the original source code or weights.
```python
import jax
import jax.numpy as jnp
import torch
from jax import export
from tojax import tojax

@tojax
def f(x):
    return torch.pow(x, 2)

inp = jnp.array([1, 2, 3])
exported = export.export(jax.jit(f))(inp)
with open("exported_fn.jax", "wb") as fh:  # "fh" avoids shadowing f
    fh.write(exported.serialize())
```
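The loading side uses jax.export.deserialize; this sketch shows a full round trip with a stand-in squaring function (in practice you would export the tojax-converted function instead), and neither torch nor tojax is needed to load and call it:

```python
import jax
import jax.numpy as jnp
from jax import export

# Export a stand-in function; in practice this is the tojax-converted one.
exported = export.export(jax.jit(lambda x: x ** 2))(jnp.array([1, 2, 3]))
with open("exported_fn.jax", "wb") as fh:
    fh.write(exported.serialize())

# Later, possibly in another process: neither torch nor tojax is required here.
with open("exported_fn.jax", "rb") as fh:
    rehydrated = export.deserialize(bytearray(fh.read()))

result = rehydrated.call(jnp.array([1, 2, 3]))  # squares the input
print(result)
```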
This even works with shape polymorphism, provided the original source code supports it:
```python
import jax
import jax.numpy as jnp
import torch
from jax import export
from tojax import tojax

@tojax
def f(x):
    return torch.pow(x, 2)

poly_shape = export.symbolic_shape("batch_size")
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct(poly_shape, jnp.float32))
with open("exported_fn.jax", "wb") as fh:
    fh.write(exported.serialize())
```
How does it work?
tojax works by swapping PyTorch function dispatches for equivalent JAX functions. Crucially, it only does this for operations on tensors that depend on the function's input. This makes tojax very permissive and lets it handle the intertwined Python+PyTorch code frequently used for pre-processing.
```python
import jax.numpy as jnp
import torch
from tojax import tojax

@tojax
def f(x):
    # These are executed by PyTorch and the result is carried over to JAX.
    a = torch.arange(10)
    a = torch.pow(a, 2)
    # The following operations depend on the function input x, so they are translated to JAX.
    y = torch.add(x, a)
    z = torch.sin(y + x)
    return z

f(jnp.zeros(()))
```
Limitations and When tojax Won't Work
While tojax handles most PyTorch code seamlessly, there are important limitations due to JAX's functional programming model and XLA compilation requirements.
Data-Dependent Control Flow
tojax will fail when your PyTorch code contains control flow that depends on tensor values (data-dependent control flow). This is because JAX requires all control flow to be traceable at compile time.
Examples That Won't Work
```python
import torch
import torch.nn as nn

# This will FAIL - conditional based on a tensor value
class ProblematicModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # Data-dependent condition
            return x * 2
        else:
            return x * 3

# This will FAIL - loop with data-dependent bounds
def problematic_function(x):
    result = x
    for i in range(int(x[0])):  # Loop bound depends on data
        result = result + 1
    return result

# This will FAIL - boolean-mask indexing produces a data-dependent shape
def problematic_indexing(x):
    mask = x > 0.5
    return x[mask]  # Dynamic shape based on data
```
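When a branch genuinely must depend on runtime values, plain JAX offers structured control-flow primitives such as jax.lax.cond and jax.lax.fori_loop. As a sketch (in pure JAX, not a tojax feature), the first failing example above can be expressed like this:

```python
import jax
import jax.numpy as jnp

def branchy(x):
    # lax.cond traces both branches once and selects at runtime,
    # so the predicate may depend on the data.
    return jax.lax.cond(x.sum() > 0, lambda v: v * 2.0, lambda v: v * 3.0, x)

print(branchy(jnp.array([1.0, -0.5])))  # sum is 0.5 > 0, so the result is x * 2
```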
Examples That Work (Static Control Flow)
```python
import torch
import torch.nn as nn

# Static control flow - works fine
class StaticModel(nn.Module):
    def __init__(self, use_layer=True):
        super().__init__()
        self.layer = nn.Linear(10, 10)
        self.use_layer = use_layer

    def forward(self, x):
        if self.use_layer:  # Condition based on a static attribute
            x = self.layer(x)
        return x

# Fixed iteration count - works fine
def static_loop_function(x):
    result = x
    for i in range(5):  # Fixed number of iterations
        result = result * 2
    return result

# Use torch.where for conditional operations
def conditional_with_where(x):
    # torch.where (translated to jnp.where) instead of Python if/else on data
    return torch.where(x > 0, x * 2, x * 3)

# Fixed-size operations work fine
def fixed_operations(x):
    # All operations have predictable shapes
    mean_pooled = x.mean(dim=-1, keepdim=True)  # (batch, features) -> (batch, 1)
    reshaped = x.reshape(x.shape[0], -1)        # (batch, features)
    return reshaped * mean_pooled               # broadcasts to (batch, features)
```
Dynamic Shapes
```python
import torch

# This will FAIL - the output shape depends on the data
def dynamic_filter(x, threshold):
    return x[x > threshold]

# Instead, keep the fixed input shape and mask the unwanted entries
def fixed_size_filter(x, threshold):
    mask = x > threshold
    # Zero out filtered entries; downstream operations can reuse the mask
    return torch.where(mask, x, torch.zeros_like(x))
```
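A common follow-up is reducing over only the kept entries; with a mask this stays fixed-shape. A sketch in plain JAX with illustrative values:

```python
import jax.numpy as jnp

x = jnp.array([0.2, 0.8, 0.5, 0.9])
mask = x > 0.5  # [False, True, False, True]

# Masked mean with a static shape, instead of x[mask].mean()
masked_mean = jnp.where(mask, x, 0.0).sum() / mask.sum()
print(masked_mean)  # (0.8 + 0.9) / 2 = 0.85
```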
Symbolic Shape Tracing and len()
When using symbolic shape tracing (e.g., jax.export with polymorphic shapes), use tensor.shape[0] instead of len(tensor). Python requires __len__ to return a concrete int, so len() cannot propagate symbolic dimensions.
```python
# Will break symbolic shape tracing
def bad(x):
    n = len(x)  # __len__ must return a concrete int, so this raises an exception
    return x.reshape(n, -1)

# Works with symbolic shapes
def good(x):
    n = x.shape[0]  # preserves the symbolic dimension
    return x.reshape(n, -1)
```
Views
```python
# Views always return copies in tojax, so flat_view will NOT share data with tensor.
# tojax does not raise an error for this, so the discrepancy is easy to miss.
def inplace_aliasing(tensor):
    flat_view = tensor.view(-1)
    tensor.add_(1.0)  # In PyTorch this would also mutate flat_view; not in tojax
    return tensor, flat_view
```
Side effects
```python
import jax
import jax.numpy as jnp
import torch
from tojax import tojax

# tojax assumes pure functions: Python side effects run only while tracing,
# so the JAX function can diverge from what repeated torch calls would return.
i = 0

def f(x):
    global i
    i += 1
    return x + i

f(torch.zeros(()))  # 1
f(torch.zeros(()))  # 2

# A single increment happens during the first trace
jax_f = jax.jit(tojax(f))
jax_f(jnp.zeros(()))  # 3
jax_f(jnp.zeros(()))  # 3

# Calling the torch function increments again
f(torch.zeros(()))  # 4
```
Advanced Features
Custom Function Registration
Register your own PyTorch-to-JAX function mappings:
```python
import jax.numpy as jnp
import torch
from tojax.functions import translates

@translates(torch.sin)
def my_jax_implementation(x):
    return jnp.sin(x) * 10
```
Module Patching
Create patches for custom modules:
```python
import torch.nn as nn
from tojax.patches import register_patch

class MyCustomModule(nn.Module):  # stand-in for your own module
    ...

@register_patch(MyCustomModule)
def patch_my_module(module):
    # Modify the module for JAX compatibility
    module.some_incompatible_flag = False
    return module
```
Tested Models
We have tested tojax on the following models:
Testing
Run the test suite:
```shell
# Using uv
uv run pytest

# Using pytest directly
pytest test/
```
License
This project is licensed under the Apache License, Version 2.0.
Acknowledgements
- JAX for the underlying array library and transformations
- PyTorch for the deep learning framework we're translating from
- E3NN and E3NN-JAX for equivariant neural networks
- Flax for neural network components
- torch2jax for inspiration
Citation
If you use tojax in your research, please cite:
```bibtex
@software{tojax2026,
  title = {tojax},
  author = {Cusp AI},
  year = {2026},
  url = {https://github.com/cusp-ai-oss/tojax}
}
```