Check the shape, ndim and dtype of tensor/ndarray arguments passed to a function
ArrayContract
from arraycontract import shape, _
import torch

# '_' matches any size; 'N' must be the same size in x and y
@shape(x=(_, 'N'), y=('N', _))
def matrix_dot(x, y):
    return x @ y

matrix_dot(torch.rand(3, 4), torch.rand(4, 5))  # OK
matrix_dot(torch.rand(3, 4), torch.rand(3, 5))  # raises AssertionError
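To illustrate how a named dimension such as `'N'` can be tied together across several arguments, here is a minimal pure-Python sketch. It is not arraycontract's implementation: `shape_contract`, the `ANY` sentinel (playing the role of `_`) and the `Shaped` stand-in are hypothetical names, and the only thing assumed about an argument is that it has a `.shape` attribute.

```python
import functools
import inspect

ANY = object()  # hypothetical wildcard sentinel, playing the role of `_`


def shape_contract(**patterns):
    """Hypothetical keyword-style shape check (not arraycontract's code).

    Each pattern is a tuple whose entries are an int (exact size), a str
    (a named dimension that must agree across all checked arguments), or
    ANY (any size).
    """
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            dims = {}  # named dimension -> size recorded at its first use
            for name, pattern in patterns.items():
                shape = tuple(bound.arguments[name].shape)
                if len(shape) != len(pattern):
                    raise AssertionError(
                        f"{name}: expected {len(pattern)} dims, got {len(shape)}")
                for size, want in zip(shape, pattern):
                    if want is ANY:
                        continue
                    if isinstance(want, str):
                        # The first occurrence binds the name; later ones must match it.
                        expected = dims.setdefault(want, size)
                        if size != expected:
                            raise AssertionError(
                                f"{name}: dim {want!r} is {size}, expected {expected}")
                    elif size != want:
                        raise AssertionError(f"{name}: expected size {want}, got {size}")
            return func(*args, **kwargs)
        return wrapper
    return decorator


class Shaped:
    """Minimal stand-in for a tensor: only the .shape attribute is used."""
    def __init__(self, *shape):
        self.shape = shape


@shape_contract(x=(ANY, "N"), y=("N", ANY))
def matrix_dot(x, y):
    return (x.shape[0], y.shape[1])  # the shape the product would have
```

Here `matrix_dot(Shaped(3, 4), Shaped(4, 5))` returns `(3, 5)`, while `matrix_dot(Shaped(3, 4), Shaped(3, 5))` raises `AssertionError`, because `x` binds `N` to 4 and `y` then has 3 in that position. The sketch raises `AssertionError` explicitly instead of using `assert`, so the checks are not stripped when Python runs with `-O`.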
from arraycontract import shape
import torch
from torch import nn

linear = nn.Linear(3, 4)

# any number of leading dimensions; the last dimension must be 3
@shape((..., 3))
def forward_linear(x):
    """
    Requires x.shape[-1] == 3.
    """
    return linear(x)

forward_linear(torch.rand(4, 5, 3))  # OK
forward_linear(torch.rand(4, 4))  # raises AssertionError
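The `(..., 3)` pattern reads as "zero or more leading dimensions, then a last dimension of size 3". One way such a pattern could be matched is sketched below in plain Python. `match_shape` is a hypothetical helper, not part of arraycontract, and it uses `None` as its own wildcard.

```python
def match_shape(shape, pattern):
    """Return True if `shape` fits `pattern` (hypothetical helper).

    Pattern entries: an int (exact size), None (any size), or a single
    Ellipsis (...) standing for zero or more dimensions.
    """
    shape = tuple(shape)
    pattern = tuple(pattern)
    if Ellipsis in pattern:
        i = pattern.index(Ellipsis)
        head, tail = pattern[:i], pattern[i + 1:]
        if Ellipsis in tail:
            raise ValueError("at most one ... is allowed in a pattern")
        if len(shape) < len(head) + len(tail):
            return False
        # Match the head against the front and the tail against the back;
        # whatever lies in between is absorbed by the ellipsis.
        return (_fits(shape[:len(head)], head)
                and _fits(shape[len(shape) - len(tail):], tail))
    return len(shape) == len(pattern) and _fits(shape, pattern)


def _fits(dims, pattern):
    """Compare equal-length sequences entry by entry; None matches anything."""
    return all(want is None or dim == want for dim, want in zip(dims, pattern))
```

With this helper, `match_shape((4, 5, 3), (..., 3))` is `True` and `match_shape((4, 4), (..., 3))` is `False`, mirroring the two calls above. The tail is sliced as `shape[len(shape) - len(tail):]` rather than `shape[-len(tail):]` so that an empty tail yields an empty slice instead of the whole shape.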
from arraycontract import dtype, ndim
import torch

@ndim(x=3, y=4)
def ndim_contract(x, y):
    print("requires x.ndim == 3 and y.ndim == 4")

@dtype(x=torch.long)
def dtype_contract(x):
    print("requires x.dtype == torch.long")
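Both `ndim` and `dtype` compare one attribute of each named argument with an expected value, so a single generic decorator can express either. The sketch below assumes exactly that; `attr_contract`, `check_ndim`, `check_dtype` and `FakeTensor` are hypothetical names, and the dtype is a plain string here only to avoid a PyTorch dependency (with real tensors, `x.dtype == torch.long` compares the same way).

```python
import functools
import inspect


def attr_contract(attr, **expected):
    """Hypothetical decorator: require getattr(arg, attr) == value per keyword."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for name, want in expected.items():
                got = getattr(bound.arguments[name], attr)
                if got != want:
                    raise AssertionError(f"{name}.{attr} is {got!r}, expected {want!r}")
            return func(*args, **kwargs)
        return wrapper
    return decorator


# The ndim and dtype variants differ only in which attribute they read.
check_ndim = functools.partial(attr_contract, "ndim")
check_dtype = functools.partial(attr_contract, "dtype")


class FakeTensor:
    """Stand-in exposing only the attributes the checks read."""
    def __init__(self, ndim, dtype):
        self.ndim = ndim
        self.dtype = dtype


@check_ndim(x=3, y=4)
def ndim_example(x, y):
    return "ok"


@check_dtype(x="int64")
def dtype_example(x):
    return "ok"
```

`ndim_example(FakeTensor(3, "float32"), FakeTensor(4, "float32"))` returns `"ok"`, while passing an `x` with `ndim == 2` raises `AssertionError`. Binding through `inspect.signature` lets the check find `x` and `y` whether they are passed positionally or by keyword.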
from arraycontract import Trigger
from arraycontract import dtype
import torch

# turn off dtype checking
Trigger.dtype_check_trigger = False

@dtype(x=torch.long)
def dtype_contract(x):
    print("x.dtype == torch.long is not required")

dtype_contract(torch.rand(3, 4).float())  # OK: the dtype check is disabled
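The example sets `Trigger.dtype_check_trigger` before defining the decorated function, so it does not show whether arraycontract reads the flag when decorating or on every call. The sketch below assumes the second design: a hypothetical `Switch.dtype_check` flag is consulted inside the wrapper at call time, so toggling it also affects functions decorated earlier. `Switch`, `require_dtype` and `FakeTensor` are illustrative names only.

```python
import functools


class Switch:
    """Hypothetical global on/off flag, modelled loosely on Trigger."""
    dtype_check = True


def require_dtype(expected):
    """Hypothetical dtype check on the first argument, honouring Switch."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(x, *args, **kwargs):
            # The flag is read per call, not when the decorator is applied,
            # so flipping it later changes the behaviour of this function too.
            if Switch.dtype_check and x.dtype != expected:
                raise AssertionError(f"x.dtype is {x.dtype!r}, expected {expected!r}")
            return func(x, *args, **kwargs)
        return wrapper
    return decorator


class FakeTensor:
    """Stand-in exposing only a dtype attribute."""
    def __init__(self, dtype):
        self.dtype = dtype


@require_dtype("int64")
def needs_long(x):
    return "ran"
```

With `Switch.dtype_check = True`, `needs_long(FakeTensor("float32"))` raises `AssertionError`; after setting `Switch.dtype_check = False` the same call returns `"ran"`. Reading the flag at decoration time instead would make the check cheaper per call but would freeze its behaviour for every function decorated before the flag changed.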
Download files
Source Distribution
arraycontract-0.0.2.tar.gz (3.8 kB)

Built Distribution
arraycontract-0.0.2-py3-none-any.whl (6.5 kB)
File details
Details for the file arraycontract-0.0.2.tar.gz.

File metadata
- Size: 3.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.1.2 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.3

File hashes
Algorithm | Hash digest
---|---
SHA256 | 798d94acf8c743c5b3f14f0f2eddbec6717f83d054f90bc7a62195c8f488a754
MD5 | bdbb7d51995a83c182bfe8a19dc81e96
BLAKE2b-256 | bc63fa02139d13931d66ce38e1c22f71b877dd31badd2daa6ac9d30c0d06f88a
File details
Details for the file arraycontract-0.0.2-py3-none-any.whl.

File metadata
- Size: 6.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.1.2 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.3

File hashes
Algorithm | Hash digest
---|---
SHA256 | 65bdc9cf8f495aee205d101f2974ab8153bb7e385f4ae23e23572fc76c78a6ab
MD5 | b4479903c980c66f94e4b413c7ecb8bd
BLAKE2b-256 | ba84266438a9e442809c93f95f1d0ef1bd9ab4274b5862287b416ce98021c382