
TensorWrap: A high-level, TensorFlow-style wrapper for JAX.

Project description

TensorWrap

TensorWrap - a full-fledged deep learning library built on JAX and inspired by TensorFlow.


What is TensorWrap?

TensorWrap is a high-performance neural network library that acts as a wrapper around JAX (another high-performance machine learning library), bringing in the familiar elements of the TensorFlow 2.x API. In its current state, it is aimed at prototyping rather than deployment.

TensorWrap works by creating a layer of abstraction over JAX's low-level API, introducing TensorFlow-like components while keeping an explicit, magic-free design philosophy. This allows TensorWrap to be fast and efficient while remaining nearly fully compatible with custom operations and other tools from the JAX ecosystem. The library also adds extra features and leverages JAX's optimizations, making it friendlier to research and educational audiences.
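
To make the design concrete, here is a minimal sketch, in plain JAX rather than TensorWrap, of the explicit-parameter style that TensorWrap's layers abstract over: parameters live in an ordinary pytree and are passed explicitly into a pure, jit-compiled function.

import jax
import jax.numpy as jnp

# Parameters are an explicit pytree (here, a dict of arrays).
params = {'kernel': jnp.ones((3, 2)), 'bias': jnp.zeros((2,))}

@jax.jit
def dense(params, x):
    # A pure function of (params, inputs) - the pattern TensorWrap's call() follows.
    return x @ params['kernel'] + params['bias']

x = jnp.ones((4, 3))
print(dense(params, x).shape)  # (4, 2)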

This is a personal project, not professionally affiliated with Google in any way. Expect bugs and several incompatibilities/differences from the original libraries. Please help by trying it out, reporting bugs, and letting me know what you think!

Contents

  • Examples
  • Current Gimmicks
  • Installation
  • Citations
  • Reference documentation

Examples

  1. Custom Layers
import tensorwrap as tf
from tensorwrap import nn

class Dense(nn.layers.Layer):
    def __init__(self, units) -> None:
        super().__init__() # Needed for tracking trainable_variables.
        self.units = units # Defining the output shape.

    def build(self, input_shape: tuple) -> None:
        super().build() # Required to let the model know that the layer is built.
        input_shape = tf.shape(input_shape) # Getting the appropriate input shape.

        # Naming each parameter so it can later be accessed from model.trainable_variables:
        self.kernel = self.add_weights([input_shape, self.units],
                                       initializer='glorot_uniform',
                                       name='kernel')
        self.bias = self.add_weights([self.units],
                                     initializer='zeros',
                                     name='bias')

    # Use call, not __call__, to define the forward pass. To support JIT
    # compilation, call is a staticmethod wrapped in tf.function.
    @staticmethod
    @tf.function
    def call(params, inputs):
        # Taking params as an input allows us to pass in model.trainable_variables later.
        return inputs @ params['kernel'] + params['bias']
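
A hypothetical usage of the layer above; the exact build and parameter-passing protocol is an assumption here and may differ in TensorWrap:

import jax.numpy as jnp

layer = Dense(units=10)
layer.build((784,)) # Assumption: build() receives the input shape and creates 'kernel' and 'bias'.
params = {'kernel': layer.kernel, 'bias': layer.bias} # Assumed layout matching call()'s params dict.
x = jnp.ones((32, 784))
print(Dense.call(params, x).shape) # Expected: (32, 10)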

  2. Just-In-Time Compilation with tf.function
import tensorwrap as tf
from tensorwrap import nn

tf.test.is_device_available(device_type='cuda') # Check whether a CUDA device is available.

@tf.function
def mse(y_pred, y_true):
    return tf.mean(tf.square(y_pred - y_true))

print(mse(100, 102)) # 4
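
Since tf.function behaves like jax.jit here, the compiled function should compose with other JAX transformations. A small sketch, assuming that compatibility (as the library's JAX-ecosystem claim suggests):

import jax

grad_fn = jax.grad(mse) # Gradient with respect to the first argument, y_pred.
print(grad_fn(100.0, 102.0)) # 2 * (100 - 102) = -4.0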

  3. Custom Models
import tensorwrap as tf
from tensorwrap import nn

class Sequential(nn.Model):
    def __init__(self, layers: list) -> None:
        super().__init__(name="Sequential") # Starts tracking internal variables and allows naming the model.
        self.layers = layers

    def __call__(self, inputs):
        x = inputs
        for layer in self.layers:
            x = layer(x)
        return x

model = Sequential([
    nn.layers.Dense(100),
    nn.layers.Dense(10)
])
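
A hypothetical forward pass through the model above, assuming the built-in Dense layers initialize themselves on first call:

import jax.numpy as jnp

x = jnp.ones((8, 64)) # A batch of 8 examples with 64 features each.
y = model(x) # Dense(100) followed by Dense(10).
print(y.shape) # Expected: (8, 10)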

Current Gimmicks

  1. Current models are all compiled by JAX's internal jit, so errors can be more cryptic than PyTorch's. However, this problem is still being worked on.

  2. Using tensorwrap.Module is currently not recommended, since other superclasses offer more functionality and ease of use.

  3. Graph execution is currently not available, which means that exported models can only be deployed within a Python environment.

Installation

Device support in TensorWrap depends on its backend, JAX. The standard install below therefore covers only the CPU version; for the GPU version, please check JAX's documentation.

pip install --upgrade pip
pip install --upgrade tensorwrap
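
For GPU support, install JAX's CUDA build first. At the time of writing, JAX's install guide suggests a command along these lines, but check JAX's documentation for the current one:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html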

On Linux, it is often necessary to first update pip to a version that supports manylinux2014 wheels. Also note that for Linux, wheels are currently released for x86_64 architectures only; other architectures require building from source. Trying to pip install on other Linux architectures may leave jaxlib uninstalled even though jax installs successfully (and then fails at runtime). These pip installations do not work on Windows and may fail silently.

Note

If any problems occur with the CUDA installation, please visit the JAX GitHub page to understand the problem with the lower-level installation.

Citations

This project has been heavily inspired by TensorFlow and, once again, is built on JAX, the open-source XLA-based machine learning framework. I recognize the authors of JAX and TensorFlow for their exceptional work, and this library does not profit in any way; it is merely an add-on to the existing community.

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.3.13},
  year = {2018},
}

Reference documentation

For details about the TensorWrap API, see the main documentation (coming soon!)

For details about JAX, see the reference documentation.

For documentation on the TensorFlow API, see the API documentation

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tensorwrap-0.0.1.4.tar.gz (22.3 kB)

Uploaded Source

Built Distribution

tensorwrap-0.0.1.4-py3-none-any.whl (26.5 kB)

Uploaded Python 3

File details

Details for the file tensorwrap-0.0.1.4.tar.gz.

File metadata

  • Download URL: tensorwrap-0.0.1.4.tar.gz
  • Upload date:
  • Size: 22.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.4

File hashes

Hashes for tensorwrap-0.0.1.4.tar.gz
Algorithm Hash digest
SHA256 158eee2a462bbf454d0ef8ee65f2023599e32b2c36464b37aef293fb89c3ab1c
MD5 591e53421b1a5992ea36d23268329304
BLAKE2b-256 8210657aa31db5fef79b24fe47d4fb3aac16cf16de2c284075ab49d55bb417a1


File details

Details for the file tensorwrap-0.0.1.4-py3-none-any.whl.

File metadata

  • Download URL: tensorwrap-0.0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 26.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.4

File hashes

Hashes for tensorwrap-0.0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 3ff0b24e86c58a06d3456a5911d10b57ec999e7fe203941e2eff8d0de7923a0c
MD5 be8dbb6810fed62e08640c8bdaad323c
BLAKE2b-256 4401a246a30a5a534c6fcc5ea4a9ef88fd437077f9ae3eda68c0ffa14232f8f7

