Skip to main content

Check the shape, ndim, and dtype of tensor/ndarray arguments passed to a function.

Project description

ArrayContract

from arraycontract import shape, _
import torch

@shape(x=(_, 'N'), y=('N', _))
def matrix_dot(x, y):
    """Return the matrix product ``x @ y`` under a shape contract.

    The ``shape`` decorator binds the symbol ``'N'`` to the last axis of
    ``x`` and requires the first axis of ``y`` to have that same size;
    ``_`` leaves the other axis of each argument unconstrained.
    """
    product = x @ y
    return product

# Contract holds: x's trailing axis (4) equals y's leading axis (4).
matrix_dot(torch.rand(3,4), torch.rand(4,5)) # OK
# Contract violated: x's trailing axis (4) != y's leading axis (3).
# Expected to raise AssertionError; left uncaught, it stops the script here.
matrix_dot(torch.rand(3,4), torch.rand(3,5)) # raise AssertionError
from arraycontract import shape, _
import torch
from torch import nn

# A single fully connected layer taking 3 input features and producing 4;
# forward_linear feeds it tensors whose trailing axis has size 3.
linear = nn.Linear(in_features=3, out_features=4)

@shape((..., 3))
def forward_linear(x):
    """
    Apply the module-level ``linear`` layer to ``x``.

    Contract: any number of leading axes is accepted, but the trailing
    axis must have size 3 (``x.shape[-1] == 3``).
    """
    output = linear(x)
    return output

# Leading axes (4, 5) are unconstrained and the trailing axis is 3, so this passes.
forward_linear(torch.rand(4,5,3)) # OK
# Trailing axis is 4, not 3: expected to raise AssertionError (uncaught).
forward_linear(torch.rand(4,4)) # raise AssertionError
from arraycontract import dtype
from arraycontract import ndim
import torch

@ndim(x=3, y=4)
def ndim_contract(x, y):
    """Demonstrate an ndim contract: ``x`` must be 3-D and ``y`` must be 4-D."""
    message = "requires x.ndim == 3 and y.ndim == 4"
    print(message)

@dtype(x=torch.long)
def dtype_contract(x):
    """Demonstrate a dtype contract: ``x`` must have dtype ``torch.long``."""
    message = "requires x.dtype == torch.long"
    print(message)
from arraycontract import Trigger
from arraycontract import dtype
import torch

# Switch off dtype checking through the library's Trigger flag. This mutates a
# class attribute, so it presumably affects every @dtype contract in the
# process (TODO confirm). It is set before dtype_contract below is defined and
# called, which is why the float tensor passed to it is accepted.
Trigger.dtype_check_trigger = False

@dtype(x=torch.long)
def dtype_contract(x):
    """Declare ``x`` as ``torch.long``; with ``Trigger.dtype_check_trigger``
    set to False above, this contract is not enforced."""
    message = "not requires x.dtype == torch.long"
    print(message)

# Accepted even though x is float32: dtype checking was switched off above.
dtype_contract(torch.rand(3, 4).float()) # OK

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

arraycontract-0.0.2.tar.gz (3.8 kB view hashes)

Uploaded Source

Built Distribution

arraycontract-0.0.2-py3-none-any.whl (6.5 kB view hashes)

Uploaded Python 3

Supported by

AWS (cloud computing and security sponsor), Datadog (monitoring), Fastly (CDN), Google (download analytics), Microsoft (PSF sponsor), Pingdom (monitoring), Sentry (error logging), StatusPage (status page)