Run PyTorch in JAX. 🤝

Project description

torch2jax

Mix-and-match PyTorch and JAX code with seamless, end-to-end autodiff, use JAX classics like jit, grad, and vmap on PyTorch code, and run PyTorch models on TPUs.

torch2jax uses abstract interpretation (aka tracing) to move JAX values through PyTorch code. As a result, you get a JAX-native computation graph that follows your PyTorch code operation for operation, up to small floating-point differences (see the FAQ).

import torch
import torchvision
from jax import jit
from torch2jax import j2t, t2j

# Reference result from plain PyTorch
vit = torchvision.models.vit_b_16().eval()
batch = torch.randn(1, 3, 224, 224)
vit(batch)
# => [-5.3352e-01, ..., 2.0390e-01]

# Same model, traced into JAX and jit-compiled
jax_vit = t2j(vit)
jax_batch = t2j(batch)
params = {k: t2j(v) for k, v in vit.named_parameters()}
jit(jax_vit)(jax_batch, state_dict=params)
# => [-5.3125e-01, ..., 2.0735e-01]

torch2jax even works with in-place PyTorch operations:

import jax.numpy as jnp
import torch
from jax import grad, vmap
from torch2jax import t2j

def f(x):
    x.add_(1)
    x.mul_(2)
    return x

f(torch.Tensor([3]))                # => torch.Tensor([8])

jax_f = t2j(f)
jax_f(jnp.array([3]))               # => jnp.array([8])
vmap(jax_f)(jnp.array([1, 2, 3]))   # => jnp.array([4, 6, 8])
grad(jax_f)(jnp.array(2.0))         # => jnp.array(2.0); grad needs a scalar output

torch2jax offers a simple API with two functions:

  1. j2t: Convert a JAX jax.numpy.ndarray to a torch.Tensor.
  2. t2j: Convert a PyTorch function, torch.nn.Module, or torch.Tensor to its JAX equivalent.

Internally, the core of torch2jax is Torchish, a class that mimics torch.Tensor via __torch_function__. A Torchish object is backed by a JAX jax.numpy.ndarray, and proxies PyTorch operations onto the underlying jax.numpy.ndarray. As a result, you get a JAX-native computation graph that exactly follows your PyTorch code.
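
To make the proxying idea concrete, here is a minimal toy sketch in plain Python. It is not torch2jax's Torchish: the names ToyProxy and lift are hypothetical, the backing value is a tuple rather than a jax.numpy.ndarray, and operations are intercepted through ordinary methods rather than __torch_function__. It only illustrates one way that in-place calls such as add_ could be replayed onto an immutable value, by rebinding the proxy's backing store.

```python
class ToyProxy:
    """Hypothetical stand-in for a tensor-like proxy (not the real Torchish)."""

    def __init__(self, value):
        self.value = tuple(value)  # immutable backing store

    def add_(self, other):
        # Looks in-place to the caller, but builds a new tuple and rebinds it.
        self.value = tuple(v + other for v in self.value)
        return self

    def mul_(self, other):
        self.value = tuple(v * other for v in self.value)
        return self


def lift(fn):
    """Turn a function written against the mutable API into a pure one."""
    def wrapped(values):
        return fn(ToyProxy(values)).value
    return wrapped


def f(x):
    x.add_(1)
    x.mul_(2)
    return x


data = (3,)
result = lift(f)(data)
print(result)  # (8,)
print(data)    # (3,): the caller's input is never mutated
```

Because every "mutation" produces a fresh value, the lifted function is pure, which is the property that transformations like jit, vmap, and grad rely on.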

Installation

PyPI

pip install torch2jax

Nix flake

torch2jax is available as a Nix flake.

$ nix shell github:samuela/torch2jax
(shell) $ python -c "from torch2jax import j2t, t2j"

FAQ

Help! I've encountered a PyTorch operation that isn't implemented yet.

torch2jax is an implementation of the PyTorch standard library written in JAX. If you come across an operation that isn't implemented yet, please file an issue and/or PR!

Adding new PyTorch operations is straightforward. Check the source for functions decorated with @implements to get started.

My PyTorch model includes dropout (or some other random operation), and does not work in training mode. Why?

JAX mandates deterministic randomness, while PyTorch does not. This leads to some API friction. torch2jax does not currently offer a means to bridge this gap. I have an idea for how to accomplish it. If this is important to you, please open an issue.

In the meantime, make sure to call .eval() on your torch.nn.Module before conversion.
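
The friction can be illustrated with a standard-library analogy. This is not JAX's or PyTorch's API: dropout_global and dropout_keyed are hypothetical helpers, and random.Random(key) only stands in for an explicit PRNG key. The first function draws from hidden global state, as PyTorch-style code does; the second derives all randomness from an argument, which is the style JAX requires.

```python
import random


def dropout_global(values, rate):
    """Implicit style: each call advances hidden global RNG state."""
    scale = 1.0 / (1.0 - rate)
    return [0.0 if random.random() < rate else v * scale for v in values]


def dropout_keyed(values, rate, key):
    """Explicit style: the output is a pure function of (values, rate, key)."""
    rng = random.Random(key)  # all randomness derives from `key`
    scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else v * scale for v in values]


values = [1.0, 2.0, 3.0, 4.0]
first = dropout_keyed(values, 0.5, key=0)
second = dropout_keyed(values, 0.5, key=0)
print(first == second)  # True: same key, same mask
```

A traced function has no natural place to hold the hidden state that dropout_global depends on, which is why a key would have to be threaded through the converted code somehow.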

My PyTorch model includes batch norm (or some other torch.nn.Module utilizing buffers), and does not work in training mode. What can I do?

Similar to the randomness story, PyTorch and JAX have different approaches to maintaining state. Operations like batch norm require maintaining running statistics. In PyTorch, this is accomplished via buffers.

torch2jax supports running batch norm models in eval() mode. Just remember to avoid taking gradients w.r.t. buffers: pass them in through state_dict, but differentiate only with respect to the parameters. For example,

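import torch
import torchvision
from jax import grad
from torch2jax import t2j

rn18 = torchvision.models.resnet18(weights="DEFAULT").eval()
# Method-style ops so the same loss works on torch tensors and JAX arrays
loss = lambda x: (x ** 2).sum()

batch = torch.randn(1, 3, 224, 224)
loss(rn18(batch)).backward()

parameters = {k: t2j(v) for k, v in rn18.named_parameters()}
buffers = {k: t2j(v) for k, v in rn18.named_buffers()}

jax_rn18 = t2j(rn18)
# Differentiate w.r.t. params only; buffers are closed over as constants
grad(lambda params, x: loss(jax_rn18(x, state_dict={**params, **buffers})))(parameters, t2j(batch))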

I have an idea for how to implement buffers, including in training mode. If this is important to you, please open an issue.

I'm seeing slightly different numerical results between PyTorch and JAX. Is it a bug?

Floating point arithmetic is hard. There are a number of sources of divergence preventing bit-for-bit equivalence:

  1. torch2jax guarantees equivalence with PyTorch standard library functions in the mathematical sense, but not necessarily in their operational execution. This can lead to slight differences in results.
  2. The JAX/XLA and PyTorch compilers apply different optimizations and should be expected to rewrite computation graphs in exciting and unpredictable ways, potentially invoking different CUDA kernels.
  3. CUDA kernels can be non-deterministic, for example as a result of floating point addition being non-associative.

Also bear in mind that floating point errors compound, so larger models will experience increased divergence.
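
The non-associativity mentioned above is easy to reproduce in plain Python, with no PyTorch or JAX involved. The snippet below is only an illustration of the underlying floating-point effect, not of any particular kernel: the same three numbers summed in two different orders give results that differ in the last bit.

```python
import math

# The same three values, grouped two different ways.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left == right)  # False
print(left, right)    # 0.6000000000000001 0.6

# Compare with a tolerance instead of exact equality.
print(math.isclose(left, right, rel_tol=1e-9))  # True
```

For the same reason, PyTorch and JAX outputs are best compared with a tolerance-based check (for example numpy.allclose) rather than exact equality.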

What about going the other way around? Running JAX code in PyTorch?

Check out jax2torch.

Contributing

PyTorch has a non-trivial API surface to cover. Contributions are welcome!

Run the test suite by running pytest inside a nix shell. Format imports with isort --dont-follow-links . (the trailing dot is the current directory).
