Easily serialize dataclasses to and from tensors

Dataclasses Tensor

The library provides a simple API for encoding and decoding Python dataclasses to and from tensors (PyTorch, TensorFlow, or NumPy arrays) based on typing annotations.

Heavily inspired by the dataclasses-json package.

Install

pip install dataclasses-tensor

Quickstart

Tensor representation of a chess game state:

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List

from dataclasses_tensor import dataclass_tensor, config

class Player(Enum):
  WHITE = 0
  BLACK = 1

class PieceType(Enum):
  PAWN = 0
  BISHOP = 1
  KNIGHT = 2
  ROOK = 3
  QUEEN = 4
  KING = 5

@dataclass
class Piece:
  piece_type: PieceType
  owner: Player

@dataclass_tensor
@dataclass
class Chess:
  num_moves: float
  next_move: Player
  board: List[Optional[Piece]] = field(metadata=config(shape=(64,)))

Working with tensors:

>>> state = Chess(100., next_move=Player.WHITE, board=[Piece(PieceType.KING, Player.BLACK)])
>>> t1 = state.to_numpy()
>>> t1
array([100.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,
         1.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,
...
>>> t1.shape
(579,)
>>> Chess.from_numpy(t1)
Chess(num_moves=100.0, next_move=<Player.WHITE: 0>, board=[Piece(piece_type=<PieceType.KING: 5>, owner=<Player.BLACK: 1>), ...])

Types

Data Classes

The library uses type annotations to determine the encoding layout. Data class member variables are serialized sequentially, in field definition order. See the supported types listed below.
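To make the sequential layout concrete, here is a minimal pure-Python sketch (not the library's implementation) that computes how many tensor slots a dataclass needs from its annotations, using the sizes implied by the examples in this README: 1 slot per primitive, one slot per Enum member, one extra flag for Optional, a one-hot selector plus all option layouts for Union, and n times the element size for a fixed-shape list. The function name layout_size and the plain "shape" metadata key are hypothetical stand-ins for the library's own config(shape=...). Applied to the Quickstart's Chess class it reproduces the 579 slots: 1 + 2 + 64 times (1 + 6 + 2).

```python
import dataclasses
import enum
import typing
from typing import List, Optional


class Player(enum.Enum):
    WHITE = 0
    BLACK = 1


class PieceType(enum.Enum):
    PAWN = 0
    BISHOP = 1
    KNIGHT = 2
    ROOK = 3
    QUEEN = 4
    KING = 5


@dataclasses.dataclass
class Piece:
    piece_type: PieceType
    owner: Player


@dataclasses.dataclass
class Chess:
    num_moves: float
    next_move: Player
    # Hypothetical plain "shape" metadata key (the library uses config(shape=...)).
    board: List[Optional[Piece]] = dataclasses.field(metadata={"shape": (64,)})


def layout_size(tp, shape=()):
    """Number of tensor slots needed to encode a value annotated as `tp`."""
    origin = typing.get_origin(tp)
    if origin is list:
        # Fixed-shape list: n copies of the element layout (recursing for nested lists).
        (inner,) = typing.get_args(tp)
        n, *rest = shape
        return n * layout_size(inner, tuple(rest))
    if origin is typing.Union:
        args = typing.get_args(tp)
        non_none = [a for a in args if a is not type(None)]
        if len(non_none) == 1:
            # Optional[X]: one "is None" flag followed by X's layout.
            return 1 + layout_size(non_none[0], shape)
        # General Union: one-hot selector plus the layouts of all options.
        return len(args) + sum(layout_size(a, shape) for a in args)
    if isinstance(tp, type) and issubclass(tp, enum.Enum):
        return len(tp)  # one-hot over the members
    if tp in (int, float, bool):
        return 1
    if dataclasses.is_dataclass(tp):
        hints = typing.get_type_hints(tp)
        # Fields are laid out one after another, in definition order.
        return sum(
            layout_size(hints[f.name], tuple(f.metadata.get("shape", ())))
            for f in dataclasses.fields(tp)
        )
    raise TypeError(f"unsupported type: {tp!r}")
```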

Primitives (int, float, bool)

The library supports numerical primitives (int, float) and bool. Strings and byte arrays are not supported.

Warning: be careful with the tensor dtype, as an implicit type conversion can lose information (for example, writing a float into an int32 tensor and reading it back won't return the original value).
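A minimal sketch of the two usual ways information is lost, using only the standard library to mimic tensor dtypes. This is an assumption-based illustration: array.array("f") stands in for a float32 tensor, and int() mimics NumPy's float-to-integer cast, which truncates toward zero for in-range values. The first case also matters for the float32 default shown under Dtype: integer-valued fields such as move counters stop being exact above 2**24.

```python
import array

# float32 has a 24-bit significand: integers above 2**24 are not all representable.
stored = array.array("f", [16777217.0])  # "f" is a C float, i.e. float32 on common platforms
print(stored[0])  # 16777216.0, silently rounded

# Float -> int conversion drops the fractional part (truncation toward zero).
values = [2.7, -1.5, 100.0]
as_int32 = [int(v) for v in values]
print(as_int32)  # [2, -1, 100]
```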

Enums

Python Enums are encoded using one-hot encoding.

>>> from dataclasses_tensor import dataclass_tensor
>>> from dataclasses import dataclass
>>> from enum import Enum
>>>
>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([0., 1., 0.])
>>> WatchList.from_numpy(_)
WatchList(matrix=<Matrix.RELOADED: 2>)
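A minimal pure-Python sketch of the same one-hot scheme, with hypothetical helpers encode_enum and decode_enum that are not part of the library. The slot is chosen by the member's position in definition order rather than by its value (this is what the Optional, Arrays, and Union examples below show), so RELOADED, the second member, sets the second slot. Decoding by argmax is an assumption of this sketch; it also tolerates soft scores such as model outputs.

```python
import enum
from typing import List, Sequence, Type, TypeVar

E = TypeVar("E", bound=enum.Enum)


class Matrix(enum.Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_enum(member: enum.Enum) -> List[float]:
    # One slot per member in definition order; the member's value is not used.
    return [1.0 if m is member else 0.0 for m in type(member)]


def decode_enum(cls: Type[E], slots: Sequence[float]) -> E:
    # Pick the member whose slot holds the largest value (argmax).
    members = list(cls)
    best = max(range(len(members)), key=lambda i: slots[i])
    return members[best]


print(encode_enum(Matrix.RELOADED))  # [0.0, 1.0, 0.0]
print(decode_enum(Matrix, [0.1, 0.2, 0.7]))  # Matrix.REVOLUTIONS
```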

Optional

typing.Optional is encoded using an additional dimension placed before the layout of the wrapped type; this slot is set to 1 when the value is None.

>>> from typing import Optional
>>>
>>> @dataclass_tensor
... @dataclass
... class MaybeWatchList:
...     matrix: Optional[Matrix]
...
>>> MaybeWatchList(Matrix.RELOADED).to_numpy()
array([0., 0., 1., 0.])
>>> MaybeWatchList.from_numpy([0., 0., 1., 0.])
MaybeWatchList(matrix=<Matrix.RELOADED: 2>)
>>> MaybeWatchList.from_numpy([1., 0., 0., 0.])
MaybeWatchList(matrix=None)

For Optional[Enum], this layout is equivalent to adding None as an extra first member of the enumeration.
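That equivalence can be sketched in a few lines of plain Python: treat the options as (None, *members) and one-hot encode over them. The helper names are hypothetical and the argmax decoding is an assumption; the outputs match the MaybeWatchList example above.

```python
import enum
from typing import List, Optional, Sequence


class Matrix(enum.Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_optional_enum(cls, member: Optional[enum.Enum]) -> List[float]:
    options = [None] + list(cls)  # None acts as an extra first member
    return [1.0 if o is member else 0.0 for o in options]


def decode_optional_enum(cls, slots: Sequence[float]):
    options = [None] + list(cls)
    # Argmax over (None, *members).
    return options[max(range(len(options)), key=lambda i: slots[i])]
```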

Arrays

Arrays, defined using either typing.List or the built-in list[...] generic (supported in Python 3.9+), require their size to be provided statically through the field's shape configuration. See example:

>>> from dataclasses import field
>>> from typing import List
>>> from dataclasses_tensor import config

>>> @dataclass_tensor
... @dataclass
... class MultipleWatchList:
...     matrices: List[Matrix] = field(metadata=config(shape=(2,)))
...
>>> MultipleWatchList([Matrix.THE_MATRIX, Matrix.RELOADED]).to_numpy()
array([1., 0., 0., 0., 1., 0.])
>>> MultipleWatchList.from_numpy([1., 0., 0., 0., 1., 0.])
MultipleWatchList(matrices=[<Matrix.THE_MATRIX: 1>, <Matrix.RELOADED: 2>])

Nested lists are supported; note the multidimensional shape configuration:

>>> @dataclass_tensor
... @dataclass
... class MultipleWatchList:
...     matrices: List[List[Matrix]] = field(metadata=config(shape=(1,2)))
...
>>> MultipleWatchList([[Matrix.THE_MATRIX, Matrix.RELOADED]]).to_numpy()
array([1., 0., 0., 0., 1., 0.])
>>> MultipleWatchList.from_numpy([1., 0., 0., 0., 1., 0.])
MultipleWatchList(matrices=[[<Matrix.THE_MATRIX: 1>, <Matrix.RELOADED: 2>]])
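A minimal sketch of how a fixed shape turns a (possibly nested) list into a flat run of element layouts in row-major order, which is why shape=(2,) and shape=(1, 2) above produce the same six values. encode_list and decode_list are hypothetical helpers; raising ValueError on a length mismatch is this sketch's choice, not documented library behavior.

```python
import enum
from typing import List, Sequence, Tuple


class Matrix(enum.Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def one_hot(member: enum.Enum) -> List[float]:
    return [1.0 if m is member else 0.0 for m in type(member)]


def encode_list(values, shape: Tuple[int, ...]) -> List[float]:
    """Flatten a (possibly nested) list of enum members in row-major order."""
    n, *rest = shape
    if len(values) != n:
        raise ValueError(f"expected {n} elements, got {len(values)}")
    out: List[float] = []
    for v in values:
        # Recurse into inner lists until the element layout is reached.
        out.extend(encode_list(v, tuple(rest)) if rest else one_hot(v))
    return out


def decode_list(cls, slots: Sequence[float], shape: Tuple[int, ...]):
    n, *rest = shape
    width = len(cls)
    for d in rest:
        width *= d  # slots taken by one element of the outermost dimension
    chunks = [list(slots[i * width:(i + 1) * width]) for i in range(n)]
    if rest:
        return [decode_list(cls, c, tuple(rest)) for c in chunks]
    members = list(cls)
    return [members[c.index(1.0)] for c in chunks]
```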

If the list element type is Optional, the list is automatically padded to the configured shape with Nones.

>>> @dataclass_tensor
... @dataclass
... class MaybeMultipleWatchList:
...     matrices: List[Optional[Matrix]] = field(metadata=config(shape=(3,)))
...
>>> MaybeMultipleWatchList([Matrix.THE_MATRIX, Matrix.RELOADED]).to_numpy()
array([0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.])
>>> MaybeMultipleWatchList.from_numpy([0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.])
MaybeMultipleWatchList(matrices=[<Matrix.THE_MATRIX: 1>, <Matrix.RELOADED: 2>, None])
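A minimal sketch of that padding for List[Optional[Matrix]]: the list is right-padded with None up to the configured length, then each slot is encoded as a None flag followed by the enum one-hot. encode_padded is a hypothetical helper, and rejecting over-long lists with ValueError is an assumption of this sketch.

```python
import enum
from typing import List, Optional


class Matrix(enum.Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_padded(values: List[Optional[Matrix]], length: int) -> List[float]:
    if len(values) > length:
        raise ValueError(f"at most {length} elements fit, got {len(values)}")
    padded = list(values) + [None] * (length - len(values))  # right-pad with None
    options = [None] + list(Matrix)  # per slot: "is None" flag, then the one-hot
    out: List[float] = []
    for v in padded:
        out.extend(1.0 if o is v else 0.0 for o in options)
    return out
```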

Union

typing.Union is encoded as a one-hot selector indicating which option of the union is present, followed by the layouts of all options in order; only the selected option's slots are filled.

>>> from typing import Union
>>>
>>> class Batman(Enum):
...     BEGINS = 1
...     DARK_KNIGHT = 2
...     DARK_KNIGHT_RISES = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     next_movie: Union[Matrix, Batman]
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([1., 0., 0., 1., 0., 0., 0., 0.])
>>> WatchList.from_numpy(_)
WatchList(next_movie=<Matrix.RELOADED: 2>)
>>> WatchList(Batman.DARK_KNIGHT).to_numpy()
array([0., 1., 0., 0., 0., 0., 1., 0.])
>>> WatchList.from_numpy(_)
WatchList(next_movie=<Batman.DARK_KNIGHT: 2>)

Decoding is fairly straightforward, but encoding can be problematic: Python's typing is not designed to keep union options separable by construction. The library tests the value against each option in order using a simple isinstance check, and the first match is used. It does not traverse generics, origins, supertypes, etc., so define Union types carefully.
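A minimal pure-Python sketch of the first-match rule and the resulting layout for Union[Matrix, Batman], using hypothetical helpers pick_option and encode_enum_union (only Enum options are handled). The last line shows the kind of pitfall to watch for: bool is a subclass of int, so with options (int, bool) a True value is claimed by int.

```python
import enum
from typing import List, Sequence


class Matrix(enum.Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


class Batman(enum.Enum):
    BEGINS = 1
    DARK_KNIGHT = 2
    DARK_KNIGHT_RISES = 3


def pick_option(value, options: Sequence[type]) -> int:
    # The first option that passes isinstance wins; later options are ignored.
    for i, opt in enumerate(options):
        if isinstance(value, opt):
            return i
    raise TypeError(f"{value!r} matches none of {options}")


def encode_enum_union(value: enum.Enum, options: Sequence[type]) -> List[float]:
    chosen = pick_option(value, options)
    selector = [1.0 if i == chosen else 0.0 for i in range(len(options))]
    body: List[float] = []
    for i, opt in enumerate(options):
        # Every option keeps its slots; only the chosen option's slots can be non-zero.
        body.extend(1.0 if i == chosen and m is value else 0.0 for m in opt)
    return selector + body


print(encode_enum_union(Matrix.RELOADED, (Matrix, Batman)))  # [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
print(pick_option(True, (int, bool)))  # 0: True is an int, so int matches first
```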

Recursive Definitions

Recursive definitions, such as linked lists, trees, and graphs, are not supported. For both usability and performance, the encoder and decoder must be able to compute the output tensor size statically.
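A minimal sketch of why a static size is impossible for recursive types: a size computation that walks the annotations would revisit the same class forever, so this sketch tracks the classes currently being expanded and raises instead. static_size is hypothetical and handles only primitives, Optional, and dataclasses; how the library itself reports recursive definitions is not described here.

```python
import dataclasses
import typing
from typing import Optional


@dataclasses.dataclass
class Point:
    x: float
    y: float


@dataclasses.dataclass
class Node:
    value: float
    next: Optional["Node"]  # linked list: its length, and so its size, depends on the data


def static_size(tp, _expanding: frozenset = frozenset()) -> int:
    if tp in (int, float, bool):
        return 1
    if typing.get_origin(tp) is typing.Union:
        # Optional[X] only: an "is None" flag plus X's layout.
        (inner,) = [a for a in typing.get_args(tp) if a is not type(None)]
        return 1 + static_size(inner, _expanding)
    if dataclasses.is_dataclass(tp):
        if tp in _expanding:
            raise TypeError(f"{tp.__name__} is recursive; its tensor size is unbounded")
        # localns lets the "Node" forward reference resolve to the class itself.
        hints = typing.get_type_hints(tp, localns={tp.__name__: tp})
        return sum(static_size(hints[f.name], _expanding | {tp}) for f in dataclasses.fields(tp))
    raise TypeError(f"unsupported type: {tp!r}")
```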

Targets

The library supports NumPy arrays, PyTorch tensors, and TensorFlow tensors as targets.

Note that these target libraries (TensorFlow, PyTorch, or NumPy) are not installed as dependencies of this package and must be provided at runtime.

Performance

Tensor layout is not cached and is recomputed for every operation. When performing many operations on the same class, it makes sense to compute the layout once and reuse it. For example:

>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> layout = WatchList.tensor_layout()
>>> WatchList(Matrix.RELOADED).to_numpy(tensor_layout=layout)
array([0., 1., 0.])
>>> WatchList.from_numpy(_, tensor_layout=layout)
WatchList(matrix=<Matrix.RELOADED: 2>)
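The same idea applies to any per-class computation. Below is a minimal sketch (not the library's code) that memoizes a hypothetical layout_for function with functools.lru_cache, so the layout is built once per class and reused by every later encode call; with this library, passing tensor_layout=... as shown above is the documented way to get the same effect.

```python
import dataclasses
import enum
import functools
from typing import List, Tuple

BUILD_COUNT = {"layout_for": 0}  # counts how often the layout is actually computed


class Matrix(enum.Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


@dataclasses.dataclass
class WatchList:
    matrix: Matrix


@functools.lru_cache(maxsize=None)
def layout_for(cls: type) -> Tuple[Tuple[str, type], ...]:
    """Hypothetical layout: (field name, Enum class) pairs, in field order."""
    BUILD_COUNT["layout_for"] += 1
    return tuple((f.name, f.type) for f in dataclasses.fields(cls))


def encode(obj) -> List[float]:
    out: List[float] = []
    for name, enum_cls in layout_for(type(obj)):  # cached after the first call
        member = getattr(obj, name)
        out.extend(1.0 if m is member else 0.0 for m in enum_cls)
    return out
```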

Advanced Features

Dtype

The library supports float and integer (long) tensors. The data type can be specified either as a parameter to the dataclass_tensor decorator (applied to all operations) or per call as an argument to the encoding function (e.g. to_numpy). See examples below.

The dtype argument is passed through to the corresponding target library: NumPy, PyTorch, or TensorFlow.

>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([0., 1., 0.], dtype=float32)
>>> WatchList(Matrix.RELOADED).to_numpy(dtype="int32")
array([0, 1, 0], dtype=int32)

or with a default set in the decorator:

>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor(dtype="int32")
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([0, 1, 0], dtype=int32)
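A minimal sketch of how a class-level default and a per-call dtype can be combined. The decorator tensor_defaults, the attribute _tensor_dtype, and the rule that an explicit call argument overrides the decorator default are all assumptions of this sketch; float32 as the fallback matches the default output shown above. Python's int and float stand in for the target library's integer and float dtypes.

```python
from typing import Callable, List, Optional

FALLBACK_DTYPE = "float32"
_CASTS = {"float32": float, "float64": float, "int32": int, "int64": int}


def tensor_defaults(dtype: Optional[str] = None) -> Callable[[type], type]:
    """Hypothetical decorator that records a per-class default dtype."""
    def wrap(cls: type) -> type:
        cls._tensor_dtype = dtype
        return cls
    return wrap


def resolve_dtype(cls: type, dtype: Optional[str] = None) -> str:
    # Assumed precedence: explicit argument > class default > fallback.
    if dtype is not None:
        return dtype
    return getattr(cls, "_tensor_dtype", None) or FALLBACK_DTYPE


def cast(values: List[float], dtype: str) -> list:
    # Convert each encoded value with the Python type standing in for `dtype`.
    return [_CASTS[dtype](v) for v in values]


class Plain:
    pass


@tensor_defaults(dtype="int32")
class IntByDefault:
    pass
```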

Batch

To encode or decode a batch, pass batch=True. See examples:

>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList.to_numpy([
...     WatchList(Matrix.THE_MATRIX),
...     WatchList(Matrix.RELOADED),
... ], batch=True)
array([[1., 0., 0.],
       [0., 1., 0.]], dtype=float32)
>>> WatchList.from_numpy(_, batch=True)
[WatchList(matrix=<Matrix.THE_MATRIX: 1>),
 WatchList(matrix=<Matrix.RELOADED: 2>)]

batch_size can be used to provide a length hint, which helps performance when working with generators:

>>> WatchList.to_numpy((
...     WatchList(Matrix.THE_MATRIX),
...     WatchList(Matrix.RELOADED),
... ), batch_size=2)
array([[1., 0., 0.],
       [0., 1., 0.]], dtype=float32)
>>> WatchList.from_numpy(_, batch_size=2)
[WatchList(matrix=<Matrix.THE_MATRIX: 1>),
 WatchList(matrix=<Matrix.RELOADED: 2>)]
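A minimal sketch of why a length hint helps: without it, a generator has to be fully materialized just to learn how many rows to allocate; with it, the output can be allocated once and filled while the generator is consumed. encode_batch is hypothetical and, for brevity, encodes bare Matrix members in place of WatchList instances; zero-filling any rows the generator does not supply is this sketch's choice.

```python
import enum
from itertools import islice
from typing import Iterable, List, Optional


class Matrix(enum.Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_batch(items: Iterable[Matrix], batch_size: Optional[int] = None) -> List[List[float]]:
    if batch_size is None:
        # No hint: materialize the iterable to learn its length.
        items = list(items)
        batch_size = len(items)
    width = len(Matrix)
    # Allocate every row up front, then fill it in place as items arrive.
    rows = [[0.0] * width for _ in range(batch_size)]
    for row, item in zip(rows, islice(items, batch_size)):
        for j, member in enumerate(Matrix):
            row[j] = 1.0 if member is item else 0.0
    return rows
```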

Custom Attribute Resolver

TBD

TODO

  • Test suite for PyTorch and TensorFlow adapters
  • Custom attribute resolver (e.g. from dict instead of class instance)
  • Pretty-print for tensor layout object

Contributing

  • Check for open issues or open a fresh issue to start a discussion around a feature idea or a bug.
  • Fork the repository on GitHub and branch from main to feature-* to start making your changes.
  • Write a test which shows that the bug was fixed or that the feature works as expected.

or simply...

  • Use it.
  • Enjoy it.
  • Spread the word.

License

Copyright © 2021, Oleksii Kachaiev.

dataclasses-tensor is licensed under the MIT license, available in the LICENSE file.

Download files


Source Distribution

Built Distribution

dataclasses_tensor-0.2.5-py3-none-any.whl (10.2 kB)

Uploaded Python 3

File details

Details for the file dataclasses-tensor-0.2.5.macosx-10.15-x86_64.tar.gz.

File metadata

  • Download URL: dataclasses-tensor-0.2.5.macosx-10.15-x86_64.tar.gz
  • Upload date:
  • Size: 15.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.6

File hashes

Hashes for dataclasses-tensor-0.2.5.macosx-10.15-x86_64.tar.gz
Algorithm Hash digest
SHA256 d43f54d29d2d12c67d096550e1242cb5ca032219f349a60ce621d901da7a7ed3
MD5 2447acac3568a773d467df109198ade4
BLAKE2b-256 7863e0e27dafbae6d591c6cfb5f92879135487f6e32d5caca1e6a92cd0c8f60c


File details

Details for the file dataclasses_tensor-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: dataclasses_tensor-0.2.5-py3-none-any.whl
  • Upload date:
  • Size: 10.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.6

File hashes

Hashes for dataclasses_tensor-0.2.5-py3-none-any.whl
Algorithm Hash digest
SHA256 9f9dd0955396866375030568e4932b3c6d7a8d5bfa076423f225215d50f9ab2f
MD5 053668f8c4c543f237313cdfbce1837c
BLAKE2b-256 db1cd51acf106f6ac25f221a1945a664cb2da3b2c020fa682a460ee30aa448b6
