
Putting TensorFlow back in PyTorch, back in TensorFlow (differentiable TensorFlow PyTorch adapters).

Project description

TfPyTh



A lightweight, differentiable adapter library that lets TensorFlow and PyTorch interact.

Install

pip install tfpyth

Example

import tensorflow as tf
import torch as th
import numpy as np
import tfpyth

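session = tf.Session()  # TensorFlow 1.x graph-mode API

def get_torch_function():
    # Build a small TensorFlow graph: c = 3a + 4b^2
    a = tf.placeholder(tf.float32, name='a')
    b = tf.placeholder(tf.float32, name='b')
    c = 3 * a + 4 * b * b

    # Wrap the graph as a differentiable PyTorch function taking (a, b)
    f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply
    return f

f = get_torch_function()
a = th.tensor(1, dtype=th.float32, requires_grad=True)
b = th.tensor(3, dtype=th.float32, requires_grad=True)
x = f(a, b)

assert x == 39.  # 3 * 1 + 4 * 3 * 3

x.backward()  # gradients flow back through the TensorFlow graph

# dc/da = 3 and dc/db = 8b = 24
assert np.allclose((a.grad, b.grad), (3., 24.))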
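The numbers in the assertions of the example follow from c = 3a + 4b², so ∂c/∂a = 3 and ∂c/∂b = 8b. A framework-free check with central finite differences (plain Python, independent of tfpyth; `c` and `central_diff` are illustrative names) reproduces them at a = 1, b = 3:

```python
def c(a: float, b: float) -> float:
    # Same expression as the TensorFlow graph in the example: c = 3a + 4b^2
    return 3 * a + 4 * b * b


def central_diff(g, x: float, h: float = 1e-4) -> float:
    # Symmetric difference quotient; its truncation error vanishes for quadratics,
    # so only floating-point rounding remains here
    return (g(x + h) - g(x - h)) / (2 * h)


a0, b0 = 1.0, 3.0
value = c(a0, b0)                              # 3*1 + 4*3*3
grad_a = central_diff(lambda a: c(a, b0), a0)  # analytic: 3
grad_b = central_diff(lambda b: c(a0, b), b0)  # analytic: 8*b = 24
print(value, round(grad_a, 6), round(grad_b, 6))  # 39.0 3.0 24.0
```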

What it's got

torch_from_tensorflow

Creates a differentiable PyTorch function that evaluates a TensorFlow output tensor, given the input placeholders that feed it.
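One way to build an adapter of this kind is a custom `torch.autograd.Function` (which the example's `.apply` call suggests): `forward` hands the inputs to the foreign framework and returns its result as a tensor, and `backward` asks the foreign framework for the vector-Jacobian product with the incoming gradient. The sketch below is an assumption about the general pattern, not tfpyth's source. It uses NumPy as a stand-in for the TensorFlow session so it runs without TensorFlow 1.x, and `NumpyQuadratic` with its hand-written derivative is hypothetical. In the TensorFlow case, `forward` would presumably call `session.run` on the output tensor and `backward` would evaluate `tf.gradients(..., grad_ys=...)`.

```python
import numpy as np
import torch


class NumpyQuadratic(torch.autograd.Function):
    """Hypothetical adapter: computes c = 3a + 4b^2 outside of PyTorch (in NumPy)."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        # Leave PyTorch: evaluate the expression on NumPy copies of the inputs
        a_np, b_np = a.detach().numpy(), b.detach().numpy()
        c_np = np.asarray(3 * a_np + 4 * b_np * b_np, dtype=np.float32)
        return torch.from_numpy(c_np)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # Vector-Jacobian product: upstream gradient times the local derivatives
        g = grad_output.detach().numpy()
        grad_a = np.asarray(g * 3.0, dtype=np.float32)
        grad_b = np.asarray(g * 8.0 * b.detach().numpy(), dtype=np.float32)
        return torch.from_numpy(grad_a), torch.from_numpy(grad_b)


f = NumpyQuadratic.apply
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
x = f(a, b)
x.backward()
print(x.item(), a.grad.item(), b.grad.item())  # 39.0 3.0 24.0
```

The `detach()` calls matter: `.numpy()` refuses tensors that require gradients, and the adapter is responsible for gradients itself through `backward`.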

eager_tensorflow_from_torch

Creates an eager TensorFlow function from a PyTorch function.

tensorflow_from_torch

Creates a TensorFlow op/tensor from a PyTorch function.

Future work

  • support JAX
  • support higher-order derivatives


Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

tfpyth-1.0.1-py3-none-any.whl (4.2 kB)

Uploaded Python 3

File details

Details for the file tfpyth-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: tfpyth-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 4.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.7.3

File hashes

Hashes for tfpyth-1.0.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       8d402bf051edfc4ac2777ff6bb88b5daafb867cf98839937b39b987b21141466
MD5          6cda0cedd9c7d32be3966d3141dff934
BLAKE2b-256  6d6c2adef3abd6923846b6557c13a9425f838aec05779541d14bf09c2118121c
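To check a downloaded wheel against the SHA256 digest listed for it, the file can be hashed in chunks with Python's standard `hashlib`. `sha256_of` is a hypothetical helper name, and the path in the usage line assumes the wheel was saved to the current directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest published for tfpyth-1.0.1-py3-none-any.whl
EXPECTED = "8d402bf051edfc4ac2777ff6bb88b5daafb867cf98839937b39b987b21141466"

# Usage, assuming the wheel sits in the current directory:
# assert sha256_of("tfpyth-1.0.1-py3-none-any.whl") == EXPECTED
```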

