Putting TensorFlow back in PyTorch, back in TensorFlow (differentiable TensorFlow PyTorch adapters).

TfPyTh

A light-weight differentiable adapter library to make TensorFlow and PyTorch interact.

Install

pip install tfpyth

Example

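import tensorflow as tf
import torch as th
import numpy as np
import tfpyth

# tf.Session and tf.placeholder are TensorFlow 1.x graph-mode APIs.
session = tf.Session()

def get_torch_function():
    # Build the TensorFlow graph: c = 3a + 4b^2.
    a = tf.placeholder(tf.float32, name='a')
    b = tf.placeholder(tf.float32, name='b')
    c = 3 * a + 4 * b * b

    # Wrap the graph as a differentiable PyTorch function of (a, b).
    f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply
    return f

f = get_torch_function()
a = th.tensor(1, dtype=th.float32, requires_grad=True)
b = th.tensor(3, dtype=th.float32, requires_grad=True)
x = f(a, b)

# 3 * 1 + 4 * 3 * 3 = 39
assert x == 39.

# Backpropagate through the TensorFlow graph into the PyTorch inputs.
x.backward()

# dc/da = 3 and dc/db = 8b = 24 at b = 3.
assert np.allclose((a.grad, b.grad), (3., 24.))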
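
The values asserted in the example can be checked without either framework. The following is a minimal sketch, not part of tfpyth: it re-implements c = 3 * a + 4 * b * b as a plain Python function and approximates both partial derivatives with central finite differences. The helper name central_difference is hypothetical and exists only for this illustration.

```python
# Plain-Python stand-in for the TensorFlow graph in the example above.
def c(a, b):
    return 3 * a + 4 * b * b


def central_difference(func, args, index, eps=1e-6):
    """Approximate the partial derivative of func w.r.t. args[index]."""
    hi = list(args)
    lo = list(args)
    hi[index] += eps
    lo[index] -= eps
    return (func(*hi) - func(*lo)) / (2 * eps)


point = (1.0, 3.0)                        # a = 1, b = 3, as in the example
value = c(*point)                         # forward value
grad_a = central_difference(c, point, 0)  # approximates dc/da
grad_b = central_difference(c, point, 1)  # approximates dc/db

print(value, round(grad_a, 4), round(grad_b, 4))
```

The numerical estimates should agree with the analytic gradients dc/da = 3 and dc/db = 8b that the example asserts after x.backward().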

What it's got

torch_from_tensorflow

Creates a differentiable PyTorch function that evaluates a TensorFlow output tensor, given its input placeholders.
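
To illustrate the general idea of a differentiable adapter, here is a conceptual sketch in plain Python. It is not tfpyth's implementation and uses neither framework: the class ForeignFunction and its forward_fn and gradient_fn parameters are hypothetical names, and the two lambdas stand in for whatever the foreign framework would compute (the output and its gradients with respect to the inputs).

```python
class ForeignFunction:
    """Hypothetical adapter around a forward callable and a gradient callable.

    In a real adapter the callables would run the foreign framework;
    here they are plain Python functions.
    """

    def __init__(self, forward_fn, gradient_fn):
        self.forward_fn = forward_fn
        self.gradient_fn = gradient_fn
        self.saved_inputs = None

    def forward(self, *inputs):
        # Remember the inputs so backward can evaluate gradients at them.
        self.saved_inputs = inputs
        return self.forward_fn(*inputs)

    def backward(self, grad_output):
        # Chain rule: scale each local gradient by the upstream gradient.
        local_grads = self.gradient_fn(*self.saved_inputs)
        return tuple(grad_output * g for g in local_grads)


# Stand-ins for the graph c = 3 * a + 4 * b * b and its gradients.
op = ForeignFunction(
    forward_fn=lambda a, b: 3 * a + 4 * b * b,
    gradient_fn=lambda a, b: (3.0, 8.0 * b),
)

y = op.forward(1.0, 3.0)   # forward pass at a = 1, b = 3
grads = op.backward(1.0)   # upstream gradient of 1, as for a scalar output
print(y, grads)
```

The forward/backward split is the same shape of interface that automatic differentiation systems expect from a custom operation, which is why wrapping a foreign computation this way can make it participate in backpropagation.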

eager_tensorflow_from_torch

Creates an eager TensorFlow function from a PyTorch function.

tensorflow_from_torch

Creates a TensorFlow op/tensor from a PyTorch function.

Future work

  • support JAX
  • support higher-order derivatives
