TensorWrap: A high level TensorFlow wrapper for JAX.
Project description
TensorWrap
TensorWrap - A full-fledged Deep Learning Library based on JAX and TensorFlow.
What is TensorWrap?
TensorWrap is high performance neural network library that acts as a wrapper around JAX (another high performance machine learning library), bringing the familiar feel of the TensorFlow (2.x.x) and PyTorch (1.x.x) for users. How
TensorWrap works by creating a layer of abstraction over JAX's low level api and introducing similar TensorFlow-like component's while supporting Autograd in native JAX operations. Additionally, the api has been updated to become simpler and more concise than TensorFlow's current API. Namespaces, internals, and various
This is a personal project, not professionally affliated with Google in any way. Expect bugs and several incompatibilities/difference from the original libraries. Please help by trying it out, reporting bugs, and letting us know what you think!
Contents
Examples
- Custom Layers
import tensorwrap as tf
from tensorwrap import keras
class Dense(keras.layers.Layer):
def __init__(self, units) -> None:
super().__init__() # Needed for making it JIT compatible.
self.units = units # Defining the output shape.
def build(self, input_shape: tuple) -> None:
super().build(input_shape) # Needed to check dimensions and build.
self.kernel = self.add_weights([input_shape[-1], self.units],
activation = 'glorot_uniform')
self.bias = self.add_weights([self.units],
activation = 'zeros')
# Use call not __call__ to define the flow. No tf.function needed either.
def call(self, inputs):
return inputs @ self.kernel + self.bias
- Custom Losses
import tensorwrap as tf
from tensorwrap import keras
class MSE(keras.losses.Loss):
def __init__(self):
pass
Current Gimmicks
-
Current models are all compiled by JAX's internal jit, so it won't be possible to view the actual internals of models, especially if it is a Sequential or Functional equations.
-
Also, using Module is currently not recommended, since other superclasses offer more functionality and ease of use.
Installation
The device installation of TensorWrap depends on its backend, being JAX. Thus, our normal install will be covering both the GPU and CPU installation.
pip install --upgrade pip
pip install --upgrade tensorwrap
On Linux, it is often necessary to first update pip
to a version that supports
manylinux2014
wheels. Also note that for Linux, we currently release wheels for x86_64
architectures only, other architectures require building from source. Trying to pip install with other Linux architectures may lead to jaxlib
not being installed alongside jax
, although jax
may successfully install (but fail at runtime).
These pip
installations do not work with Windows, and may fail silently; see
above.
Note
If any problems occur with cuda installation, please visit the JAX github page, in order to understand the problem with lower API installation.
Citations
This project have been heavily inspired by TensorFlow and once again, is built on the open-source machine learning XLA framework JAX. Therefore, I recongnize the authors of JAX and TensorFlow for the exceptional work they have done and understand that my library doesn't profit in any sort of way, since it is merely an add-on to the already existing community.
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.3.13},
year = {2018},
}
Reference documentation
For details about the TensorWrap API, see the [main documentation] (coming soon!)
For details about JAX, see the reference documentation.
For documentation on TensorFlow API, see the API documentation
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
File details
Details for the file tensorwrap-0.0.0.3.tar.gz
.
File metadata
- Download URL: tensorwrap-0.0.0.3.tar.gz
- Upload date:
- Size: 15.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.1
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 20c79e1163436ffce29fede102cb6d3e4dffcc1ad1fada709f21bfe9468a139d |
|
MD5 | f83be492944e5b24b6d1b493c8792d58 |
|
BLAKE2b-256 | 983d18f240e957356af4c19e17791a8504abee0dc95c58cadfab8a1913ac7765 |