
Provides type annotations for torch tensors.

Project description

Basic usage (specifying only the number of dimensions):

import torch
from statictorch import *

def f(x: Tensor2d):
    pass

f(torch.zeros([2, 3]))  # No runtime error, but static type checkers might say: "Tensor" is not assignable to "Tensor2d"
f(Tensor2d(torch.zeros([2, 3])))  # It's ok.

Note that TensorXd(x) simply returns x unchanged: it is a hint for static type checkers, not a runtime wrapper. At runtime, x and TensorXd(x) are therefore indistinguishable, and isinstance(TensorXd(x), TensorXd) returns False.
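The following is a minimal sketch, not statictorch's actual source, of how a class can behave this way. It assumes a hypothetical FakeTensor stand-in for torch.Tensor so it runs without PyTorch. The trick is that when __new__ returns an object that is not an instance of the class, Python skips __init__ and the constructor call evaluates to that object.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, shape: list[int]) -> None:
        self.shape = shape


class Tensor2d(FakeTensor):
    """Hypothetical TensorXd-like class: an annotation target, not a runtime wrapper."""

    def __new__(cls, x: FakeTensor) -> "Tensor2d":
        # __new__ returns x, which is not a Tensor2d instance, so Python
        # skips __init__ and Tensor2d(x) evaluates to x itself.
        return x  # type: ignore[return-value]


x = FakeTensor([2, 3])
y = Tensor2d(x)
print(y is x)                   # True: the very same object comes back
print(isinstance(y, Tensor2d))  # False: no Tensor2d instance was ever created
```

Type checkers generally take the declared return type of __new__ at face value, so y is treated as a Tensor2d during type checking even though it is a plain tensor at runtime.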

The TensorXd classes (Tensor2d, Tensor3d, and so on) are generic, so you can attach a dimension descriptor to each dimension:

from statictorch import *

# In previous versions, inheriting from TensorDimensionDescriptor was mandatory. 
# While this requirement is now relaxed to accommodate the new TypeVarTuple-based TensorNd, 
# it remains highly recommended for better semantic clarity and consistency.
class Batch(TensorDimensionDescriptor):
    pass

class Channel(TensorDimensionDescriptor):
    pass

class Sample(TensorDimensionDescriptor):
    pass

def train_on(x: Tensor3d[Batch, Channel, Sample], y: Tensor3d[Channel, Batch, Sample]):
    ...

def load_data() -> tuple[Tensor3d[Batch, Channel, Sample], Tensor3d[Batch, Channel, Sample]]:
    ...

data_x, data_y = load_data()
train_on(data_x, data_y)  # Pylance: Argument of type "Tensor3d[Batch, Channel, Sample]" cannot be assigned to parameter "y" of type "Tensor3d[Channel, Batch, Sample]" in function "train_on"

# To solve the problem:
data_y_transposed = data_y.transpose(1, 0)
train_on(data_x, Tensor3d(data_y_transposed))

# If you really do want to pass data_y unchanged, you can simply override the type checker with:
train_on(data_x, Tensor3d(data_y))
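At runtime the descriptors are only type arguments: nothing checks them when train_on is called. The sketch below uses a hypothetical three-parameter generic class (not statictorch's definition) to show that typing.get_args recovers the descriptors in order, and that Tensor3d[Batch, Channel, Sample] and Tensor3d[Channel, Batch, Sample] are different parameterizations, mirroring the mismatch Pylance reports above. The TypeVar names D0, D1, D2 are illustrative assumptions.

```python
from typing import Generic, TypeVar, get_args, get_origin


class TensorDimensionDescriptor:
    """Hypothetical stand-in for statictorch's descriptor base class."""


class Batch(TensorDimensionDescriptor):
    pass


class Channel(TensorDimensionDescriptor):
    pass


class Sample(TensorDimensionDescriptor):
    pass


# Invariant type variables, one per dimension (hypothetical names).
D0 = TypeVar("D0")
D1 = TypeVar("D1")
D2 = TypeVar("D2")


class Tensor3d(Generic[D0, D1, D2]):
    """Hypothetical rank-3 generic: each type argument labels one dimension."""


bcs = Tensor3d[Batch, Channel, Sample]
cbs = Tensor3d[Channel, Batch, Sample]

print(get_origin(bcs) is Tensor3d)                # True
print(get_args(bcs) == (Batch, Channel, Sample))  # True: order is preserved
print(bcs == cbs)                                 # False: same descriptors, different order
```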

typing.cast is also useful, especially when passing a list[TensorXd] to functions such as torch.stack:

import statictorch
import typing
from torch import Tensor

def work_on(tensors: list[Tensor]):
    ...

my_tensors: list[statictorch.Tensor0d] = []
work_on(my_tensors)  # Pylance: Argument of type "list[Tensor0d]" cannot be assigned to parameter "tensors" of type "list[Tensor]" in function "work_on"

# To solve the problem:
work_on(typing.cast(list[Tensor], my_tensors))

# If you find typing.cast too long:
work_on(statictorch.anify(my_tensors))

# If the function is defined by yourself, use Sequence when applicable:
def my_work_on(tensors: typing.Sequence[Tensor]):
    ...
my_work_on(my_tensors)  # OK: Sequence is covariant, so list[Tensor0d] is accepted as Sequence[Tensor].
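Type checkers reject list[Tensor0d] where list[Tensor] is expected because lists are mutable: a function taking list[Tensor] is allowed to append a plain Tensor. The pure-Python sketch below uses hypothetical stand-in classes Tensor and Tensor0d (not torch's or statictorch's) to show that hole, and that typing.cast is a no-op at runtime. The anify defined here is an assumption about how such a helper could be written, not statictorch's actual code.

```python
import typing
from typing import Any, Sequence


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


class Tensor0d(Tensor):
    """Hypothetical stand-in for statictorch.Tensor0d."""


def anify(x: object) -> Any:
    # Hypothetical helper: returns its argument unchanged, typed as Any,
    # so the checker stops complaining. It does nothing at runtime.
    return x


def append_plain(tensors: list[Tensor]) -> None:
    # Legal for list[Tensor], but it would break a list[Tensor0d].
    tensors.append(Tensor())


def count(tensors: Sequence[Tensor]) -> int:
    # Sequence is read-only, so accepting Sequence[Tensor0d] here is safe.
    return len(tensors)


my_tensors: list[Tensor0d] = [Tensor0d(), Tensor0d()]
print(count(my_tensors))  # 2, accepted by checkers thanks to covariance

same = typing.cast(list[Tensor], my_tensors)
print(same is my_tensors)  # True: cast only changes the static type

append_plain(anify(my_tensors))       # silences the checker...
print(type(my_tensors[-1]).__name__)  # ...but prints "Tensor": the list is no longer all Tensor0d
```

This suggests reserving the cast-based workarounds for functions that only read from the list, as torch.stack does.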

In the new version, we introduce TensorNd, which supports an arbitrary number of dimensions:

from typing import Any
import torch
from statictorch.tensor_nd import Tensor3d, TensorNd


# a 15-d tensor
t15: TensorNd[Any, Any, Any, Any, Any,
              Any, Any, Any, Any, Any, 
              Any, Any, Any, Any, Any] = TensorNd(torch.zeros([1] * 15))


t2: TensorNd[Any, Any] = TensorNd(torch.zeros([1, 1]))
t3: TensorNd[Any, Any, Any] = t2  # Pylance: Type "TensorNd[Any, Any]" is not assignable to declared type "TensorNd[Any, Any, Any]"


# Due to technical limitations in Python, we must define TensorXd using inheritance.
# As a result, a TensorNd cannot be directly assigned to a TensorXd even if their dimensions match.
t3_: Tensor3d = t3  # Pylance: Type "TensorNd[Any, Any, Any]" is not assignable to declared type "Tensor3d[Unknown, Unknown, Unknown]"
# Conversely, a Tensor3d can be used directly as a TensorNd.
t3 = t3_



Download files


Source Distribution

statictorch-0.1.1.tar.gz (2.7 kB)

Uploaded Source

Built Distribution


statictorch-0.1.1-py3-none-any.whl (4.2 kB)

Uploaded Python 3

File details

Details for the file statictorch-0.1.1.tar.gz.

File metadata

  • Download URL: statictorch-0.1.1.tar.gz
  • Upload date:
  • Size: 2.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.28

File hashes

Hashes for statictorch-0.1.1.tar.gz
Algorithm Hash digest
SHA256 43ec1bd281c539b8298594ea623677207bd45fc2dcdfdd31b323b7de20791b23
MD5 beb86330acb3d83f52feba41f4f69076
BLAKE2b-256 decdab11ceaa9cc51b4697edecfed2f131eb11a0952176ffa354073565078483


File details

Details for the file statictorch-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: statictorch-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 4.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.28

File hashes

Hashes for statictorch-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 9f0c333f0799a38bb369784467fed243850b72a0299dc594a1745c57365eaafc
MD5 c55120482f39f92067f180b1ad0e7341
BLAKE2b-256 f6d94d0316c914f92cf6fd6b10e0202fb4a184172b9413e80693e91d855f7673

