TensorGuard helps guard against bad tensor shapes

Tensor Guard

TensorGuard guards against bad tensor shapes in any tensor-based library (e.g. NumPy, PyTorch, TensorFlow) using an intuitive, symbolic syntax.

Installation

pip install tensorguard

Basic Usage

```python
import numpy as np  # could be tensorflow or torch as well
import tensorguard as tg

# tensorguard = tg.TensorGuard()  # could also be done in an OOP fashion
img = np.ones([64, 32, 32, 3])
flat_img = np.ones([64, 1024])
labels = np.ones([64])

# check shape consistency
tg.guard(img, "B, H, W, C")  # records B=64, H=32, W=32, C=3
try:
    tg.guard(labels, "B, 1")  # raises an error: rank 1 tensor vs. rank 2 template
except Exception as err:
    print(err)
try:
    tg.guard(flat_img, "B, H*W*C")  # raises an error: 1024 != 32*32*3
except Exception as err:
    print(err)

# guard also returns the tensor, so it can be inlined
mean_img = tg.guard(np.mean(img, axis=0), "H, W, C")

# more readable reshapes
flat_img = tg.reshape(img, "B, H*W*C")

# evaluate templates
assert tg.get_dims("H, W*C+1") == [32, 97]
```
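The example relies on named dimensions being bound the first time they appear and then checked against that binding. Below is a minimal pure-Python sketch of how a guard of this kind could work. It is an illustration under stated assumptions, not TensorGuard's implementation. `MiniGuard` and `ShapeMismatch` are hypothetical names. The sketch operates on plain shape tuples (such as `arr.shape`) and only supports integers, named dimensions, `*` wildcards and a single `...`. Arithmetic entries like `H*W*C` are left out.

```python
from typing import Dict, Sequence


class ShapeMismatch(ValueError):
    """Raised when a shape does not fit its template (hypothetical name)."""


class MiniGuard:
    """Hypothetical, simplified shape guard (not TensorGuard's real code).

    Supports integers, named dimensions, "*" wildcards and one "...".
    """

    def __init__(self) -> None:
        self.dims: Dict[str, int] = {}  # named dimensions bound so far

    @staticmethod
    def _match(size: int, token: str, dims: Dict[str, int]) -> None:
        if token == "*":
            return  # wildcard: any size is accepted
        if token.isdigit():
            if size != int(token):
                raise ShapeMismatch(f"expected {token}, got {size}")
            return
        # Named dimension: bind it on first use, compare on later uses.
        bound = dims.setdefault(token, size)
        if bound != size:
            raise ShapeMismatch(f"{token}: expected {bound}, got {size}")

    def guard(self, shape: Sequence[int], template: str) -> Sequence[int]:
        tokens = [t.strip() for t in template.split(",")]
        if "..." in tokens:
            i = tokens.index("...")
            head, tail = tokens[:i], tokens[i + 1:]
            if len(shape) < len(head) + len(tail):
                raise ShapeMismatch(f"rank {len(shape)} too small for {template!r}")
            # "..." absorbs the middle dimensions; only both ends are matched.
            pairs = list(zip(shape[:len(head)], head))
            pairs += list(zip(shape[len(shape) - len(tail):], tail))
        elif len(shape) != len(tokens):
            raise ShapeMismatch(f"rank {len(shape)} != {len(tokens)} for {template!r}")
        else:
            pairs = list(zip(shape, tokens))
        candidate = dict(self.dims)  # work on a copy ...
        for size, token in pairs:
            self._match(size, token, candidate)
        self.dims = candidate  # ... and commit only if every entry matched
        return shape


sg = MiniGuard()
sg.guard((64, 32, 32, 3), "B, H, W, C")  # binds B, H, W and C
sg.guard((64, 10), "B, *")               # B must still be 64
```

In this sketch, new bindings are committed only after every entry matches. As a result, a failed check does not leave half-bound dimensions behind.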

Shape Template Syntax

The shape template mini-DSL supports many different ways of specifying shapes:

  • numbers: "64, 32, 32, 3"
  • named dimensions: "B, width, height2, channels"
  • wildcards: "B, *, *, *"
  • ellipsis: "B, ..., 3"
  • addition, subtraction, multiplication, division: "B*N, W/2, H*(C+1)"
  • dynamic dimensions: "?, H, W, C" (only matches [None, H, W, C])
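Arithmetic entries can be evaluated once the named dimensions are known. This is roughly what `tg.get_dims('H, W*C+1')` returns in the usage example. The following sketch is hypothetical: `eval_dims` is not part of TensorGuard's API, and it assumes that `/` must divide exactly. It parses each comma-separated entry with Python's standard `ast` module and accepts only integer literals, names, parentheses and `+ - * /`. Wildcards, ellipsis and `?` are not handled here.

```python
import ast
import operator
from typing import Callable, Dict, List


def _exact_div(a: int, b: int) -> int:
    # Assumption: "/" in a shape template must divide without remainder.
    if b == 0 or a % b:
        raise ValueError(f"{a}/{b} is not an exact integer division")
    return a // b


# Only these binary operators are allowed in a template entry.
_OPS: Dict[type, Callable[[int, int], int]] = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: _exact_div,
}


def _eval(node: ast.AST, dims: Dict[str, int]) -> int:
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return node.value
    if isinstance(node, ast.Name):
        if node.id not in dims:
            raise KeyError(f"unknown dimension {node.id!r}")
        return dims[node.id]
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        op = _OPS[type(node.op)]
        return op(_eval(node.left, dims), _eval(node.right, dims))
    raise ValueError(f"unsupported template element: {ast.dump(node)}")


def eval_dims(template: str, dims: Dict[str, int]) -> List[int]:
    """Hypothetical helper: evaluate every comma-separated template entry."""
    return [
        _eval(ast.parse(entry.strip(), mode="eval").body, dims)
        for entry in template.split(",")
    ]


print(eval_dims("H, W*C+1", {"H": 32, "W": 32, "C": 3}))  # [32, 97]
```

Walking a restricted AST instead of calling `eval` keeps arbitrary Python out of shape templates.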

Original repository: https://github.com/Qwlouse/shapeguard

Download files

Source Distribution

tensorguard-1.0.0.tar.gz (27.4 kB)

File details

Details for the file tensorguard-1.0.0.tar.gz.

File metadata

  • Download URL: tensorguard-1.0.0.tar.gz
  • Upload date:
  • Size: 27.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/53.1.0 requests-toolbelt/0.9.1 tqdm/4.58.0 CPython/3.6.13

File hashes

Hashes for tensorguard-1.0.0.tar.gz:

  • SHA256: 0f72f7b0ec0eaf7d79972a7ef3aae61d0c9be2bf84d9736ffa606514a37bb6ac
  • MD5: 074cd3d74306058674858af97ba43895
  • BLAKE2b-256: 9bd9a9c3fa8294c62d15de7cc234ab7e004bdf4efcea6b3d9e0dbd5d25634df9
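Python's standard `hashlib` module is enough to check a downloaded archive against the published SHA256 digest. This is a minimal sketch. `sha256_of` and `is_authentic` are hypothetical helper names, and the archive is assumed to have been saved locally.

```python
import hashlib

# SHA256 digest published for tensorguard-1.0.0.tar.gz.
EXPECTED_SHA256 = "0f72f7b0ec0eaf7d79972a7ef3aae61d0c9be2bf84d9736ffa606514a37bb6ac"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hypothetical helper: hex SHA256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks so large files never sit fully in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def is_authentic(path: str) -> bool:
    """True if the file at `path` matches the published SHA256 digest."""
    return sha256_of(path) == EXPECTED_SHA256
```

For example, `is_authentic("tensorguard-1.0.0.tar.gz")` should return `True` for an intact download in the current directory.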
