Provides type annotations for torch tensors.
Basic usage (specifying only the number of dimensions):
```python
import torch
from statictorch import *

def f(x: Tensor2d):
    pass

f(torch.zeros([2, 3]))  # No runtime error, but static type checkers might say: "Tensor" is not assignable to "Tensor2d"
f(Tensor2d(torch.zeros([2, 3])))  # OK
```
Note that TensorNd() simply returns the tensor it is given. At runtime, therefore, x and TensorNd(x) are indistinguishable, and isinstance(TensorNd(x), TensorNd) returns False.
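One way to get exactly this runtime behavior is a class whose `__new__` hands back its argument unchanged. The sketch below assumes that mechanism; it is not taken from statictorch's source, and `FakeTensor` and `Tensor2dSketch` are hypothetical stand-ins so that it runs without PyTorch.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape


class Tensor2dSketch(FakeTensor):
    """Hypothetical marker type: meaningful to type checkers, a no-op at runtime."""

    def __new__(cls, tensor: FakeTensor) -> "Tensor2dSketch":
        # Return the argument itself instead of creating an instance of cls.
        # Since the result is not an instance of cls, Python skips __init__
        # and the caller simply gets the original object back.
        return tensor  # type: ignore[return-value]


x = FakeTensor((2, 3))
y = Tensor2dSketch(x)
print(y is x)                         # True: the very same object
print(isinstance(y, Tensor2dSketch))  # False: it is still a plain FakeTensor
```

Because `__new__` returns an object that is not an instance of the class, the wrapper costs nothing at runtime, and for the same reason `isinstance` cannot detect it.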
The TensorNd types are defined as generic classes, so you can add dimension descriptors to them:
```python
import torch
from statictorch import *

class Batch(TensorDimensionDescriptor):
    pass

class Channel(TensorDimensionDescriptor):
    pass

class Sample(TensorDimensionDescriptor):
    pass

def train_on(x: Tensor3d[Batch, Channel, Sample], y: Tensor3d[Channel, Batch, Sample]):
    ...

def load_data() -> tuple[Tensor3d[Batch, Channel, Sample], Tensor3d[Batch, Channel, Sample]]:
    # Placeholder data (batch=8, channel=2, sample=100) so the example runs
    return Tensor3d(torch.zeros([8, 2, 100])), Tensor3d(torch.zeros([8, 2, 100]))

data_x, data_y = load_data()
train_on(data_x, data_y)  # Pylance: Argument of type "Tensor3d[Batch, Channel, Sample]" cannot be assigned to parameter "y" of type "Tensor3d[Channel, Batch, Sample]" in function "train_on"

# To solve the problem:
data_y_transposed = data_y.transpose(1, 0)
train_on(data_x, Tensor3d(data_y_transposed))

# If you really do want to pass data_y unchanged, you can simply override the type checker:
train_on(data_x, Tensor3d(data_y))
```
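Re-wrapping with `Tensor3d(...)` works, but the axis swap itself can also be encoded in a function signature so that the type checker tracks the new descriptor order. This is a sketch built only on the standard `typing` module, assuming descriptors behave as ordinary generic type parameters, as the example above suggests; `FakeTensor`, `DimSketch`, `Tensor3dSketch`, `swap_first_two`, and the shapes are all hypothetical.

```python
from typing import Generic, TypeVar, cast


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (only a shape and transpose)."""

    def __init__(self, shape: tuple[int, ...]) -> None:
        self.shape = shape

    def transpose(self, dim0: int, dim1: int) -> "FakeTensor":
        # Swap two entries of the shape, mimicking Tensor.transpose(dim0, dim1).
        dims = list(self.shape)
        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
        return FakeTensor(tuple(dims))


class DimSketch:
    """Hypothetical descriptor base, playing the role of TensorDimensionDescriptor."""


D0 = TypeVar("D0", bound=DimSketch)
D1 = TypeVar("D1", bound=DimSketch)
D2 = TypeVar("D2", bound=DimSketch)


class Tensor3dSketch(FakeTensor, Generic[D0, D1, D2]):
    """Hypothetical 3-D marker type; only type checkers look at D0, D1, D2."""


class Batch(DimSketch):
    pass


class Channel(DimSketch):
    pass


class Sample(DimSketch):
    pass


def swap_first_two(x: Tensor3dSketch[D0, D1, D2]) -> Tensor3dSketch[D1, D0, D2]:
    # The signature records that dims 0 and 1 trade places, so a type checker
    # can follow the descriptor order through the call. cast() is a no-op at runtime.
    return cast("Tensor3dSketch[D1, D0, D2]", x.transpose(0, 1))


def load_data() -> Tensor3dSketch[Batch, Channel, Sample]:
    return cast("Tensor3dSketch[Batch, Channel, Sample]", FakeTensor((8, 2, 100)))


data_y = load_data()
data_y_t = swap_first_two(data_y)  # a checker should infer Tensor3dSketch[Channel, Batch, Sample]
print(data_y_t.shape)  # (2, 8, 100)
```

With statictorch itself, the same signature would presumably be written against `Tensor3d` using TypeVars bound to `TensorDimensionDescriptor`; that is an untested assumption.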
typing.cast is also useful, especially when you need to pass a list[TensorNd] to functions such as torch.stack that expect a list of plain Tensors:
```python
from statictorch import *
import typing
from torch import Tensor

def work_on(tensors: list[Tensor]):
    ...

my_tensors: list[Tensor0d] = []
work_on(my_tensors)  # Pylance: Argument of type "list[Tensor0d]" cannot be assigned to parameter "tensors" of type "list[Tensor]" in function "work_on"
# To solve the problem:
work_on(typing.cast(list[Tensor], my_tensors))
```
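The error comes from `list` being invariant: a function that accepts `list[Tensor]` could append a plain `Tensor` to the list, which would break a caller holding a `list[Tensor0d]`. The standard-library-only sketch below uses hypothetical stand-in classes `Base` and `Derived` to show both the `typing.cast` workaround and, for functions you control, accepting the covariant, read-only `collections.abc.Sequence` instead, which needs no cast.

```python
import typing
from collections.abc import Sequence


class Base:
    """Hypothetical stand-in for torch.Tensor."""


class Derived(Base):
    """Hypothetical stand-in for a TensorNd type such as Tensor0d."""


def work_on(items: list[Base]) -> int:
    # list is invariant: a checker rejects list[Derived] here, because this
    # function would be allowed to append a plain Base to the caller's list.
    return len(items)


def work_on_seq(items: Sequence[Base]) -> int:
    # Sequence is read-only and covariant, so Sequence[Derived] is accepted.
    return len(items)


my_items: list[Derived] = [Derived(), Derived()]
print(work_on(typing.cast(list[Base], my_items)))  # 2 (cast does nothing at runtime)
print(work_on_seq(my_items))                        # 2, no cast needed
```

For third-party functions such as torch.stack, whose signatures you cannot change, the cast remains the practical option.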
Download files
File details
Details for the file statictorch-0.0.2.tar.gz.
File metadata
- Download URL: statictorch-0.0.2.tar.gz
- Upload date:
- Size: 3.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 82db587893cdddfeb47fae771cbe356adb81865b3ef2db38b39b58db1dcee9e9 |
| MD5 | 0bd74314db48f5a0547f7341d200a83f |
| BLAKE2b-256 | 98fd0fee1a57f0a7b9c0d31607702466be2360c30fd31c5c7151297b7f359edc |
File details
Details for the file statictorch-0.0.2-py3-none-any.whl.
File metadata
- Download URL: statictorch-0.0.2-py3-none-any.whl
- Upload date:
- Size: 3.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.12.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e1a18084239fde98f915e927b1b36790d9139e87e5cfecdbc2538618fd0cfe25 |
| MD5 | 625121636a96c5f39494e9958d42cc57 |
| BLAKE2b-256 | 786802ed4052ac443bd5b456c8fdf12273f3937a5b1ff323eb771dc385250b49 |