
Provides type annotations for torch tensors.

Project description

Basic usage (specifying only the number of dimensions):

import torch
from statictorch import *

def f(x: Tensor2d):
    pass

f(torch.zeros([2, 3]))  # No runtime error, but static type checkers may report: "Tensor" is not assignable to "Tensor2d"
f(Tensor2d(torch.zeros([2, 3])))  # OK: the tensor is explicitly annotated as Tensor2d

Note that TensorNd(x) simply returns the given tensor x. At runtime, x and TensorNd(x) are therefore indistinguishable, and isinstance(TensorNd(x), TensorNd) returns False.
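A minimal sketch of how a constructor can behave this way. Tensor2dSketch is a hypothetical class, not statictorch's actual source, and a nested list stands in for a torch tensor so the example runs without PyTorch. Overriding __new__ to return its argument means no new object is ever created (Python then skips __init__), while typing.cast only changes what the type checker sees.

```python
import typing


class Tensor2dSketch:
    """Hypothetical stand-in for a TensorNd-style annotation class (not statictorch's code)."""

    def __new__(cls, tensor: object) -> "Tensor2dSketch":
        # typing.cast does nothing at runtime; it only tells the type checker
        # to treat `tensor` as a Tensor2dSketch. No wrapper object is created.
        return typing.cast("Tensor2dSketch", tensor)


data = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # stand-in for torch.zeros([2, 3])
wrapped = Tensor2dSketch(data)
print(wrapped is data)                      # True: the very same object comes back
print(isinstance(wrapped, Tensor2dSketch))  # False: no Tensor2dSketch instance exists
```

This is why the annotation costs nothing at runtime, and also why isinstance cannot be used to check it.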

The TensorNd classes are generic, so you can parameterize them with dimension descriptors:

import torch
from statictorch import *

class Batch(TensorDimensionDescriptor):
    pass

class Channel(TensorDimensionDescriptor):
    pass

class Sample(TensorDimensionDescriptor):
    pass

def train_on(x: Tensor3d[Batch, Channel, Sample], y: Tensor3d[Channel, Batch, Sample]):
    ...

def load_data() -> tuple[Tensor3d[Batch, Channel, Sample], Tensor3d[Batch, Channel, Sample]]:
    # Dummy data for illustration: batch size 4, 2 channels, 8 samples
    return Tensor3d(torch.zeros([4, 2, 8])), Tensor3d(torch.zeros([4, 2, 8]))

data_x, data_y = load_data()
train_on(data_x, data_y)  # Pylance: Argument of type "Tensor3d[Batch, Channel, Sample]" cannot be assigned to parameter "y" of type "Tensor3d[Channel, Batch, Sample]" in function "train_on"

# To fix this, swap the Batch and Channel dimensions and re-wrap the result:
data_y_transposed = data_y.transpose(1, 0)  # now laid out as [Channel, Batch, Sample]
train_on(data_x, Tensor3d(data_y_transposed))

# If you really do want to pass data_y unchanged, you can bypass the type checker by re-wrapping it:
train_on(data_x, Tensor3d(data_y))
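The descriptors work because each class is generic over one type parameter per dimension, so the dimension order becomes part of the static type. The sketch below assumes a plain typing.Generic with three TypeVars; DimSketch and Tensor3dSketch are hypothetical stand-ins for TensorDimensionDescriptor and Tensor3d, not the library's actual definitions.

```python
from typing import Generic, TypeVar, cast, get_args


class DimSketch:
    """Hypothetical base for dimension labels (plays the role of TensorDimensionDescriptor)."""


D0 = TypeVar("D0", bound=DimSketch)
D1 = TypeVar("D1", bound=DimSketch)
D2 = TypeVar("D2", bound=DimSketch)


class Tensor3dSketch(Generic[D0, D1, D2]):
    """Hypothetical generic annotation class; the type parameters exist only for the checker."""

    def __new__(cls, tensor: object) -> "Tensor3dSketch[D0, D1, D2]":
        # Pass-through at runtime, mirroring the documented behavior of TensorNd(x).
        return cast("Tensor3dSketch[D0, D1, D2]", tensor)


class Batch(DimSketch): ...
class Channel(DimSketch): ...
class Sample(DimSketch): ...


# TypeVars are invariant by default and Batch/Channel are unrelated classes,
# so a type checker treats these two parameterizations as incompatible types.
BCS = Tensor3dSketch[Batch, Channel, Sample]
CBS = Tensor3dSketch[Channel, Batch, Sample]

# At runtime the labels are merely recorded on the generic alias; nothing is checked.
print(get_args(BCS))
```

Because the descriptors are ordinary classes, they carry no shape information at runtime; the check happens entirely in the static type checker.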

typing.cast is also useful, especially when you need to pass a list[TensorNd] to a function such as torch.stack that expects a list of plain tensors:

from statictorch import *
import typing
from torch import Tensor

def work_on(tensors: list[Tensor]):
    ...

my_tensors: list[Tensor0d] = []
work_on(my_tensors)  # Pylance: Argument of type "list[Tensor0d]" cannot be assigned to parameter "tensors" of type "list[Tensor]" in function "work_on"

# To fix this, cast the list to the type the function expects:
work_on(typing.cast(list[Tensor], my_tensors))
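The cast is needed because list is invariant: assuming Tensor0d is a subtype of Tensor, a function that accepts list[Tensor] could still append a plain Tensor to the list, which would break a list[Tensor0d]. When you control the callee, annotating the parameter as the read-only, covariant Sequence[Tensor] avoids the cast; for library functions such as torch.stack, whose signatures you cannot change, typing.cast remains the simplest option. The sketch below uses hypothetical BaseSketch and Derived0dSketch classes in place of Tensor and Tensor0d so that it runs without PyTorch.

```python
import typing
from collections.abc import Sequence


class BaseSketch:
    """Hypothetical stand-in for torch.Tensor."""


class Derived0dSketch(BaseSketch):
    """Hypothetical stand-in for Tensor0d (a subtype of the base)."""


def work_on_list(items: list[BaseSketch]) -> int:
    # This signature would allow items.append(BaseSketch()), which would corrupt
    # a list[Derived0dSketch]; that is why list[...] is invariant.
    return len(items)


def work_on_sequence(items: Sequence[BaseSketch]) -> int:
    # Sequence is read-only and covariant, so a list[Derived0dSketch] is accepted.
    return len(items)


my_items: list[Derived0dSketch] = [Derived0dSketch(), Derived0dSketch()]
print(work_on_sequence(my_items))  # accepted by the type checker without a cast
print(work_on_list(typing.cast(list[BaseSketch], my_items)))  # needs the cast
```

Both calls print 2; the difference is only in what the type checker accepts, since typing.cast returns the same list object at runtime.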



Download files


Source Distribution

statictorch-0.0.2.tar.gz (3.2 kB)

Uploaded Source

Built Distribution


statictorch-0.0.2-py3-none-any.whl (3.4 kB)

Uploaded Python 3

File details

Details for the file statictorch-0.0.2.tar.gz.

File metadata

  • Download URL: statictorch-0.0.2.tar.gz
  • Size: 3.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for statictorch-0.0.2.tar.gz
Algorithm Hash digest
SHA256 82db587893cdddfeb47fae771cbe356adb81865b3ef2db38b39b58db1dcee9e9
MD5 0bd74314db48f5a0547f7341d200a83f
BLAKE2b-256 98fd0fee1a57f0a7b9c0d31607702466be2360c30fd31c5c7151297b7f359edc


File details

Details for the file statictorch-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: statictorch-0.0.2-py3-none-any.whl
  • Size: 3.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for statictorch-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 e1a18084239fde98f915e927b1b36790d9139e87e5cfecdbc2538618fd0cfe25
MD5 625121636a96c5f39494e9958d42cc57
BLAKE2b-256 786802ed4052ac443bd5b456c8fdf12273f3937a5b1ff323eb771dc385250b49
