
Provides type annotations for torch tensors.

Project description

Basic usage (specifying only the number of dimensions):

import torch
from statictorch import *

def f(x: Tensor2d):
    pass

f(torch.zeros([2, 3]))  # No runtime error, but a static type checker may report: "Tensor" is not assignable to "Tensor2d"
f(Tensor2d(torch.zeros([2, 3])))  # OK: the argument is now typed as Tensor2d.

Note that TensorXd(x) (where X is the number of dimensions, e.g. Tensor2d) simply returns the tensor it is given. At runtime you therefore cannot distinguish x from TensorXd(x), and isinstance(TensorXd(x), TensorXd) returns False.
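This runtime behaviour can be reproduced with a class whose `__new__` returns its argument unchanged. The sketch below illustrates that general technique under stated assumptions; it is not statictorch's actual source. To stay dependency-free it uses a hypothetical stand-in `Tensor` class in place of `torch.Tensor`.

```python
class Tensor:
    """Hypothetical stand-in for torch.Tensor, used only to keep this sketch dependency-free."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape


class Tensor2d(Tensor):
    """Hypothetical marker type: distinct for a static type checker, a no-op at runtime."""

    def __new__(cls, tensor: Tensor) -> "Tensor2d":
        # The returned object is not an instance of cls, so Python skips __init__
        # and hands the original tensor straight back to the caller.
        return tensor  # type: ignore[return-value]

    def __init__(self, tensor: Tensor) -> None:
        # Never runs (see __new__); declared only so type checkers accept Tensor2d(x).
        pass


x = Tensor((2, 3))
y = Tensor2d(x)
print(y is x)                   # True: the very same object
print(isinstance(y, Tensor2d))  # False: no Tensor2d instance was ever created
```

The type checker sees `Tensor2d(x)` as a `Tensor2d` because of the declared return type of `__new__`, while the program only ever handles the original object.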

The TensorXd types are generic classes, so you can parameterize them with dimension descriptors:

from statictorch import *

# Earlier versions required every descriptor to inherit from TensorDimensionDescriptor.
# The requirement has been relaxed to accommodate the new TypeVarTuple-based TensorNd,
# but inheriting from it is still strongly recommended for clarity and consistency.
class Batch(TensorDimensionDescriptor):
    pass

class Channel(TensorDimensionDescriptor):
    pass

class Sample(TensorDimensionDescriptor):
    pass

def train_on(x: Tensor3d[Batch, Channel, Sample], y: Tensor3d[Channel, Batch, Sample]):
    ...

def load_data() -> tuple[Tensor3d[Batch, Channel, Sample], Tensor3d[Batch, Channel, Sample]]:
    ...

data_x, data_y = load_data()
train_on(data_x, data_y)  # Pylance: Argument of type "Tensor3d[Batch, Channel, Sample]" cannot be assigned to parameter "y" of type "Tensor3d[Channel, Batch, Sample]" in function "train_on"

# To fix this, swap the first two dimensions and wrap the result again (transpose returns a plain Tensor):
data_y_transposed = data_y.transpose(1, 0)
train_on(data_x, Tensor3d(data_y_transposed))

# If you really do want to pass data_y unchanged, you can simply cheat the type checker:
train_on(data_x, Tensor3d(data_y))
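Why does the checker reject `data_y` here? Each distinct tuple of descriptors produces a distinct type, and the order of the arguments matters. The sketch below shows this with plain `typing.Generic`; the `Tensor3d`, `TensorDimensionDescriptor`, and type-variable names are hypothetical stand-ins, not statictorch's real definitions.

```python
from typing import Generic, TypeVar, get_args


class TensorDimensionDescriptor:
    """Hypothetical stand-in for statictorch's descriptor base class."""


class Batch(TensorDimensionDescriptor): ...
class Channel(TensorDimensionDescriptor): ...
class Sample(TensorDimensionDescriptor): ...


# One type variable per dimension (unbound, mirroring the relaxed requirement).
D0 = TypeVar("D0")
D1 = TypeVar("D1")
D2 = TypeVar("D2")


class Tensor3d(Generic[D0, D1, D2]):
    """Hypothetical generic marker; the real type also wraps torch.Tensor."""


BCS = Tensor3d[Batch, Channel, Sample]
CBS = Tensor3d[Channel, Batch, Sample]

print(get_args(BCS))  # the three descriptor classes, in declaration order
print(BCS == CBS)     # False: same descriptors, different order, different type
```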

typing.cast is also useful, especially when you want to pass a list[TensorXd] to functions such as torch.stack:

import statictorch
import typing
from torch import Tensor

def work_on(tensors: list[Tensor]):
    ...

my_tensors: list[statictorch.Tensor0d] = []
work_on(my_tensors)  # Pylance: Argument of type "list[Tensor0d]" cannot be assigned to parameter "tensors" of type "list[Tensor]" in function "work_on"

# To solve the problem:
work_on(typing.cast(list[Tensor], my_tensors))

# If you find typing.cast too long:
work_on(statictorch.anify(my_tensors))

# If you control the function's signature, prefer Sequence where applicable:
def my_work_on(tensors: typing.Sequence[Tensor]):
    ...
my_work_on(my_tensors)  # OK: Sequence is covariant, while list is invariant.
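The reason list[Tensor0d] is rejected is that list is mutable: a function typed to accept list[Tensor] may legally insert a plain Tensor. The sketch below bypasses the check with typing.cast to show what can go wrong. Tensor and Tensor0d here are hypothetical stand-ins for torch.Tensor and statictorch.Tensor0d.

```python
import typing


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


class Tensor0d(Tensor):
    """Hypothetical stand-in for statictorch.Tensor0d."""


def work_on(tensors: list[Tensor]) -> None:
    # Perfectly legal for a list[Tensor]: append an arbitrary Tensor.
    tensors.append(Tensor())


my_tensors: list[Tensor0d] = [Tensor0d()]

# The cast silences the checker, but my_tensors now holds an element
# that is not a Tensor0d, despite its declared type.
work_on(typing.cast(list[Tensor], my_tensors))

# typing.cast does nothing at runtime: it returns the very same list object.
print(typing.cast(list[Tensor], my_tensors) is my_tensors)  # True
print(type(my_tensors[-1]).__name__)                        # Tensor
```

In short, a cast only silences the checker. It is safe when the callee merely reads the list, which is exactly the case Sequence expresses.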

The new version introduces TensorNd, a TypeVarTuple-based type that supports an arbitrary number of dimensions:

from typing import Any
import torch
from statictorch.tensor_nd import Tensor3d, TensorNd


# a 15-d tensor
t15: TensorNd[Any, Any, Any, Any, Any,
              Any, Any, Any, Any, Any, 
              Any, Any, Any, Any, Any] = TensorNd(torch.zeros([1] * 15))


t2: TensorNd[Any, Any] = TensorNd(torch.zeros([1, 1]))
t3: TensorNd[Any, Any, Any] = t2  # Pylance: Type "TensorNd[Any, Any]" is not assignable to declared type "TensorNd[Any, Any, Any]"


# Because of limitations in Python's type system, TensorXd has to be defined via inheritance.
# As a result, a TensorNd cannot be assigned directly to a TensorXd, even when the number of dimensions matches.
t3_: Tensor3d = t3  # Pylance: Type "TensorNd[Any, Any, Any]" is not assignable to declared type "Tensor3d[Unknown, Unknown, Unknown]"
# Conversely, a Tensor3d can be used directly as a TensorNd.
t3 = t3_


Download files

Source Distribution

statictorch-0.2.0.tar.gz (2.8 kB)

Built Distribution

statictorch-0.2.0-py3-none-any.whl (4.3 kB)

File details

Details for the file statictorch-0.2.0.tar.gz.

File metadata

  • Download URL: statictorch-0.2.0.tar.gz
  • Upload date:
  • Size: 2.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.28

File hashes

Hashes for statictorch-0.2.0.tar.gz
Algorithm Hash digest
SHA256 0e28efcca71aa53eef590dbe69326da5cb2bf0faa7da2f73b0f37f17824ce955
MD5 d2e7f1b29102aa7ce957d8361cbb9f60
BLAKE2b-256 3d635b0438a293622e06c4df9707f13b1332ea1c32bc2323e208e9a3f5e1e460

File details

Details for the file statictorch-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: statictorch-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 4.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.28

File hashes

Hashes for statictorch-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 a9d72f055851a7ce6a5c075f87633212a2627113faee9985d3638833a78df57f
MD5 a072d68fcb53bd331ffb102b269997a6
BLAKE2b-256 9e97ab872884b8d136b8a51953c63267feeeb45e3bd3131082b3ab1e8f0c0301
