
Pydantic support for parsing, validation, and serialization of tensors


pydantic-tensor

Adds support for parsing, validation, and serialization of common tensors (np.ndarray, torch.Tensor, tensorflow.Tensor, jax.Array) to Pydantic.



Installation

pip install pydantic-tensor

Usage

Validation

from typing import Annotated, Any, Literal

import numpy as np
import tensorflow as tf
import torch
from pydantic import BaseModel, Field

from pydantic_tensor import Tensor

# allow only integers greater than or equal to 2 and less than or equal to 3
DimType = Annotated[int, Field(ge=2, le=3)]


class Model(BaseModel):
    #              tensor type                          shape                    dtype
    tensor: Tensor[torch.Tensor | np.ndarray[Any, Any], tuple[DimType, DimType], Literal["int32", "int64"]]


parsed = Model.model_validate({"tensor": np.ones((2, 2), dtype="int32")})
# access the parsed tensor via the "value" property
parsed.tensor.value

# invalid shapes (each of these calls fails validation)
Model.model_validate({"tensor": np.ones((1, 1), dtype="int32")})
Model.model_validate({"tensor": np.ones((4, 4), dtype="int32")})
Model.model_validate({"tensor": np.ones(2, dtype="int32")})
Model.model_validate({"tensor": np.ones((2, 2, 2), dtype="int32")})

# invalid dtype (fails validation)
Model.model_validate({"tensor": np.ones((2, 2), dtype="float32")})

# successfully validate np.ndarray
Model.model_validate({"tensor": np.ones((2, 2), dtype="int32")})
# convert tf.Tensor to torch.Tensor
Model.model_validate({"tensor": tf.ones((2, 2), dtype=tf.int32)})
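To make the constraints of the Model above concrete, here is a plain-Python sketch of the checks its annotation implies: exactly two dimensions, each of size 2 or 3, and a dtype of int32 or int64. The check_tensor helper is hypothetical and not part of pydantic_tensor, which performs these checks through Pydantic validation; it works on a bare shape tuple and dtype name so it needs no tensor library.

```python
from typing import Sequence

# Hypothetical helper, NOT part of pydantic_tensor. It restates the constraints
# of the Model above on a plain shape tuple and dtype name:
#   shape: tuple[DimType, DimType] with DimType = Annotated[int, Field(ge=2, le=3)]
#   dtype: Literal["int32", "int64"]
ALLOWED_DTYPES = ("int32", "int64")
NDIM = 2
DIM_MIN, DIM_MAX = 2, 3


def check_tensor(shape: Sequence[int], dtype: str) -> list[str]:
    """Return error messages for a shape/dtype pair; an empty list means valid."""
    errors = []
    if len(shape) != NDIM:
        errors.append(f"expected {NDIM} dimensions, got {len(shape)}")
    else:
        for axis, size in enumerate(shape):
            if not DIM_MIN <= size <= DIM_MAX:
                errors.append(
                    f"dimension {axis} has size {size}, expected {DIM_MIN} to {DIM_MAX}"
                )
    if dtype not in ALLOWED_DTYPES:
        errors.append(f"dtype {dtype!r} is not one of {ALLOWED_DTYPES}")
    return errors


# The same cases as in the validation example
for shape, dtype in [((2, 2), "int32"), ((1, 1), "int32"), ((4, 4), "int32"),
                     ((2,), "int32"), ((2, 2, 2), "int32"), ((2, 2), "float32")]:
    print(shape, dtype, check_tensor(shape, dtype) or "valid")
```

Note that a shape with the wrong number of dimensions is rejected before the individual sizes are looked at, just as a tuple of the wrong length cannot match tuple[DimType, DimType].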

Parsing

The JSON representation of a tensor contains:

  • the tensor's binary data in little-endian byte order, encoded in Base64
  • the shape of the tensor
  • the dtype of the tensor

from typing import Any

import numpy as np
from pydantic import BaseModel

from pydantic_tensor import Tensor


class Model(BaseModel):
    tensor: Tensor[Any, Any, Any]


parsed = Model.model_validate({"tensor": np.ones((2, 2), dtype="float32")})
# serialize to JSON: {"tensor":{"shape":[2,2],"dtype":"float32","data":"AACAPwAAgD8AAIA/AACAPw=="}}
json_dump = parsed.model_dump_json()
# parse back to tensor: array([[1., 1.], [1., 1.]], dtype=float32)
Model.model_validate_json(json_dump).tensor.value
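The "data" field in the JSON above can be reproduced with the standard library alone, which shows how the three parts of the representation fit together. This is a sketch of the format described above, not pydantic_tensor's own code; it assumes the bytes are laid out in row-major (C) order, which does not matter for a tensor of ones.

```python
import base64
import json
import struct

# Flattened 2x2 float32 tensor of ones (row-major order assumed)
shape = [2, 2]
values = [1.0, 1.0, 1.0, 1.0]

# "<" forces little-endian byte order; "f" is a 4-byte IEEE 754 float32
raw = struct.pack(f"<{len(values)}f", *values)
payload = {
    "shape": shape,
    "dtype": "float32",
    "data": base64.b64encode(raw).decode("ascii"),
}
encoded = json.dumps(payload, separators=(",", ":"))
print(encoded)
# {"shape":[2,2],"dtype":"float32","data":"AACAPwAAgD8AAIA/AACAPw=="}

# Decoding reverses the steps: Base64 -> bytes -> little-endian float32 values
decoded = json.loads(encoded)
raw_back = base64.b64decode(decoded["data"])
count = len(raw_back) // struct.calcsize("<f")
values_back = list(struct.unpack(f"<{count}f", raw_back))
print(values_back)  # [1.0, 1.0, 1.0, 1.0]
```

Each 1.0 is stored as the bytes 00 00 80 3F, so four of them Base64-encode to the repeating "AACAPw"/"gD8A"/"AIA/" pattern seen in the serialized model.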

DType Collections

The types Int, UInt, Float, Complex, and BFloat in pydantic_tensor.types are unions of the dtypes their names describe. For example, Int is defined as Literal["int8", "int16", "int32", "int64"].

from typing import Any

import numpy as np
from pydantic import BaseModel

from pydantic_tensor import Tensor
from pydantic_tensor.types import Int


class Model(BaseModel):
    tensor: Tensor[Any, Any, Int]


for dtype in ["int8", "int16", "int32", "int64"]:
    Model.model_validate({"tensor": np.ones((2, 2), dtype=dtype)})  # success

Model.model_validate({"tensor": np.ones((2, 2), dtype="float32")})  # failure
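Since these collections are ordinary typing.Literal types, the standard typing tools apply to them. The sketch below uses a local copy of Int (as defined above) so it runs without the library, and builds a hypothetical IntOrFloat32 collection by nesting Literals. Passing such a combined Literal to Tensor is an assumption here; it is, however, equal to a plain Literal listing the same dtype names.

```python
from typing import Literal, get_args

# Local copy of the definition stated above; in real code use
# `from pydantic_tensor.types import Int`.
Int = Literal["int8", "int16", "int32", "int64"]

# get_args lists the dtype names a Literal accepts
print(get_args(Int))  # ('int8', 'int16', 'int32', 'int64')

# The typing module flattens nested Literals (Python 3.9.1+), so a custom
# collection can be built from an existing one.
IntOrFloat32 = Literal[Int, "float32"]  # hypothetical name, not in pydantic_tensor
print(get_args(IntOrFloat32))  # ('int8', 'int16', 'int32', 'int64', 'float32')

for dtype in ["int8", "int64", "float32", "uint8"]:
    print(dtype, dtype in get_args(IntOrFloat32))
```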

Lazy Tensors

Use JaxArray, NumpyNDArray, TensorflowTensor, and TorchTensor as lazy versions of the tensor types. They only handle tensors once the corresponding library (jax, numpy, tensorflow, or torch) has been imported somewhere else in the program.

from typing import Any

import numpy as np
from pydantic import BaseModel

from pydantic_tensor import Tensor
from pydantic_tensor.backend.torch import TorchTensor


class Model(BaseModel):
    tensor: Tensor[TorchTensor, Any, Any]


Model.model_validate({"tensor": np.ones((2, 2), dtype="float32")})  # failure

import torch

Model.model_validate({"tensor": np.ones((2, 2), dtype="float32")})  # success
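How pydantic_tensor implements this internally is not described here. One common way to get the same behavior is to look a library up in sys.modules instead of importing it, so nothing is loaded until user code imports it. The LazyBackend class below is a hypothetical sketch of that idea, not part of the library's API, and uses the standard-library fractions module as a stand-in for an optional tensor library.

```python
import sys


class LazyBackend:
    """Hypothetical backend that only activates once its library is imported.

    Not part of pydantic_tensor; it illustrates checking sys.modules
    instead of importing an optional dependency.
    """

    def __init__(self, module_name: str, type_name: str) -> None:
        self.module_name = module_name
        self.type_name = type_name

    def matches(self, obj: object) -> bool:
        module = sys.modules.get(self.module_name)
        if module is None:
            # Library not imported anywhere yet: skip it without importing it
            return False
        return isinstance(obj, getattr(module, self.type_name))


# A backend for a library that is never imported never matches
missing = LazyBackend("some_uninstalled_tensor_lib", "Tensor")
print(missing.matches(object()))  # False

# The stdlib fractions module stands in for an optional library like torch
backend = LazyBackend("fractions", "Fraction")
import fractions  # once imported anywhere, the backend can handle its type

print(backend.matches(fractions.Fraction(1, 2)))  # True
print(backend.matches(0.5))  # False
```

The design keeps optional dependencies optional: a program that never imports a library never pays its import cost, and a model annotated with that library's lazy type simply rejects inputs until it is loaded.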

Development

Install pre-commit hooks

pre-commit install

Lint

hatch run lint:all

Test

hatch run test:test

Check spelling

hatch run spell



Download files


Source Distribution

pydantic_tensor-0.2.0.tar.gz (14.2 kB)

Uploaded Source

Built Distribution

pydantic_tensor-0.2.0-py3-none-any.whl (16.3 kB)

Uploaded Python 3

File details

Details for the file pydantic_tensor-0.2.0.tar.gz.

File metadata

  • Download URL: pydantic_tensor-0.2.0.tar.gz
  • Size: 14.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.27.0

File hashes

Hashes for pydantic_tensor-0.2.0.tar.gz
Algorithm Hash digest
SHA256 2ded2f5f344aed56a894beca13898a8a5de16b99ca8acc5412abb6d6ca38fd3c
MD5 e39e86d0ac0dd3bec0538c81bc8ab9ef
BLAKE2b-256 dce012cad7e7b0b6cf96a7ef2db4adf4b9fec8142dff94cb8ca3f08bafd847af

File details

Details for the file pydantic_tensor-0.2.0-py3-none-any.whl.

File hashes

Hashes for pydantic_tensor-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b78274dcf555d098e4d9946d4feff4324022d8f9868feb716ee8435c7232f601
MD5 b5006f32e7f9e1a05edaa0dad6fa9e7c
BLAKE2b-256 8e63d24eab7e8ecaca07440d40c10676d5af157df7fba0503e04d170bd8d0b0a