Easily serialize dataclasses to and from tensors
Dataclasses Tensor
The library provides a simple API for encoding and decoding Python dataclasses
to and from tensors (PyTorch tensors or NumPy arrays) based on typing
annotations.
Heavily inspired by the dataclasses-json package.
Install
pip install dataclasses-tensor
Quickstart
Tensor representation of a chess game state:
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List
from dataclasses_tensor import dataclass_tensor, config

class Player(Enum):
    WHITE = 0
    BLACK = 1

class PieceType(Enum):
    PAWN = 0
    BISHOP = 1
    KNIGHT = 2
    ROOK = 3
    QUEEN = 4
    KING = 5

@dataclass
class Piece:
    piece_type: PieceType
    owner: Player

@dataclass_tensor
@dataclass
class Chess:
    num_moves: float
    next_move: Player
    board: List[Optional[Piece]] = field(metadata=config(shape=(64,)))
Working with tensors:
>>> state = Chess(100., next_move=Player.WHITE, board=[Piece(PieceType.KING, Player.BLACK)])
>>> t1 = state.to_numpy()
>>> t1
array([100., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       ...
>>> t1.shape
(579,)
>>> Chess.from_numpy(t1)
Chess(num_moves=100.0, next_move=<Player.WHITE: 0>, board=[Piece(piece_type=<PieceType.KING: 5>, owner=<Player.BLACK: 1>), ...])
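The 579 entries follow directly from the encoding rules described in the Types section below. A quick plain-Python arithmetic check of that size (no library calls involved):

```python
# Size of each Chess field, following the encoding rules from the Types section
num_moves = 1                  # float: a single scalar
next_move = 2                  # Player: one-hot over WHITE, BLACK
piece = 6 + 2                  # Piece: one-hot PieceType (6 members) + one-hot Player (2 members)
board = 64 * (1 + piece)       # 64 Optional[Piece] slots: a "None" flag plus the Piece layout
total = num_moves + next_move + board
print(total)  # 579
```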
Types
Data Classes
The library uses type annotations to determine the encoding layout. Data class fields are serialized sequentially, in the order they are declared. The supported types are listed below.
Primitives (int, float, bool)
The library supports numerical primitives (int, float) and bool. Strings and byte arrays are not supported.
Warning: be careful with the tensor dtype, as implicit type conversion can lose information (for example, writing a float into an int32 tensor and reading it back won't return the original value).
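The same kind of loss can be reproduced with the standard library alone. In this sketch the array module stands in for a typed tensor (an assumption for illustration only; it is not what dataclasses-tensor uses): a 32-bit float array rounds every value, and an integer array only accepts floats after an explicit, truncating int() conversion.

```python
import array

values = [0.1, 2.7]

# "f" is a 32-bit C float: each value is rounded to the nearest float32
as_float32 = array.array("f", values)
print(as_float32[0])        # 0.10000000149011612, not 0.1

# "i" is a C int (typically 32-bit); floats must be converted first, and int() truncates
as_int = array.array("i", [int(v) for v in values])
print(list(as_int))         # [0, 2]
```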
Enums
Python Enums are one-hot encoded: one slot per member, in definition order.
>>> from dataclasses_tensor import dataclass_tensor
>>> from dataclasses import dataclass
>>> from enum import Enum
>>>
>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([0., 1., 0.])
>>> WatchList.from_numpy(_)
WatchList(matrix=<Matrix.RELOADED: 2>)
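To make the layout concrete, here is a minimal pure-Python sketch of one-hot encoding an Enum, with plain lists standing in for tensors. encode_enum and decode_enum are hypothetical helpers, not part of the library's API. The sketch assumes, as the Optional, Arrays and Batch examples show, that a member's slot is its position in the class (not its value), and it assumes decoding picks the largest entry.

```python
from enum import Enum
from typing import List, Type


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_enum(value: Enum) -> List[float]:
    # One slot per member, in definition order; the member's position is used, not its value
    members = list(type(value))
    vec = [0.0] * len(members)
    vec[members.index(value)] = 1.0
    return vec


def decode_enum(enum_cls: Type[Enum], vec: List[float]) -> Enum:
    # Assumption: pick the slot with the largest entry (argmax)
    members = list(enum_cls)
    best = max(range(len(members)), key=lambda i: vec[i])
    return members[best]


print(encode_enum(Matrix.RELOADED))            # [0.0, 1.0, 0.0]
print(decode_enum(Matrix, [0.0, 1.0, 0.0]))    # Matrix.RELOADED
```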
Optional
The typing.Optional type is encoded using an additional leading dimension, placed before the layout of the underlying type; it is set when the value is None.
>>> from typing import Optional
>>>
>>> @dataclass_tensor
... @dataclass
... class MaybeWatchList:
...     matrix: Optional[Matrix]
...
>>> MaybeWatchList(Matrix.RELOADED).to_numpy()
array([0., 0., 1., 0.])
>>> MaybeWatchList.from_numpy([0., 0., 1., 0.])
MaybeWatchList(matrix=<Matrix.RELOADED: 2>)
>>> MaybeWatchList.from_numpy([1., 0., 0., 0.])
MaybeWatchList(matrix=None)
The layout for Optional[Enum] is equivalent to adding None to the enumeration as an extra first member.
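The equivalence with an extra enum member is easy to see in a minimal pure-Python sketch. The helpers below are hypothetical, specialised to the Matrix enum (redefined so the snippet runs on its own), and decoding by argmax is an assumption.

```python
from enum import Enum
from typing import List, Optional


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_optional_enum(value: Optional[Matrix]) -> List[float]:
    # Leading "is None" flag, followed by the one-hot enum layout
    members = list(Matrix)
    vec = [0.0] * (1 + len(members))
    if value is None:
        vec[0] = 1.0
    else:
        vec[1 + members.index(value)] = 1.0
    return vec


def decode_optional_enum(vec: List[float]) -> Optional[Matrix]:
    # Treat None as an extra, first "member" and take the argmax over all slots
    options = [None] + list(Matrix)
    return options[max(range(len(options)), key=lambda i: vec[i])]


print(encode_optional_enum(Matrix.RELOADED))  # [0.0, 0.0, 1.0, 0.0]
print(encode_optional_enum(None))             # [1.0, 0.0, 0.0, 0.0]
```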
Arrays
Arrays, defined using either typing.List or the built-in list[] (Python 3.9+), require their size to be provided statically. For example:
>>> from dataclasses import field
>>> from typing import List
>>> from dataclasses_tensor import config
>>>
>>> @dataclass_tensor
... @dataclass
... class MultipleWatchList:
...     matrices: List[Matrix] = field(metadata=config(shape=(2,)))
...
>>> MultipleWatchList([Matrix.THE_MATRIX, Matrix.RELOADED]).to_numpy()
array([1., 0., 0., 0., 1., 0.])
>>> MultipleWatchList.from_numpy([1., 0., 0., 0., 1., 0.])
MultipleWatchList(matrices=[<Matrix.THE_MATRIX: 1>, <Matrix.RELOADED: 2>])
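A static size means the output length is known before looking at any value: elements are simply laid out back to back. A minimal pure-Python sketch of that idea follows; encode_fixed_list and decode_fixed_list are hypothetical helpers, and raising on a length mismatch is this sketch's own choice, not documented library behaviour.

```python
from enum import Enum
from typing import List, Sequence


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def one_hot(value: Matrix) -> List[float]:
    return [1.0 if m is value else 0.0 for m in Matrix]


def encode_fixed_list(values: Sequence[Matrix], size: int) -> List[float]:
    # The size comes from the field configuration, not from the data
    if len(values) != size:
        raise ValueError(f"expected {size} elements, got {len(values)}")
    out: List[float] = []
    for v in values:
        out.extend(one_hot(v))  # each element occupies len(Matrix) consecutive slots
    return out


def decode_fixed_list(vec: Sequence[float], size: int) -> List[Matrix]:
    width, members = len(Matrix), list(Matrix)
    result = []
    for i in range(size):
        chunk = vec[i * width:(i + 1) * width]  # slice out element i
        result.append(members[max(range(width), key=lambda j: chunk[j])])
    return result


print(encode_fixed_list([Matrix.THE_MATRIX, Matrix.RELOADED], 2))
# [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
```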
Nested lists are supported; note the multidimensional shape configuration:
>>> @dataclass_tensor
... @dataclass
... class MultipleWatchList:
...     matrices: List[List[Matrix]] = field(metadata=config(shape=(1, 2)))
...
>>> MultipleWatchList([[Matrix.THE_MATRIX, Matrix.RELOADED]]).to_numpy()
array([1., 0., 0., 0., 1., 0.])
>>> MultipleWatchList.from_numpy([1., 0., 0., 0., 1., 0.])
MultipleWatchList(matrices=[[<Matrix.THE_MATRIX: 1>, <Matrix.RELOADED: 2>]])
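Nested lists can be flattened by walking the shape one dimension at a time (row-major order). A minimal pure-Python sketch, with encode_nested and flat_size as hypothetical helpers:

```python
from enum import Enum
from typing import Any, List, Tuple


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def one_hot(value: Matrix) -> List[float]:
    return [1.0 if m is value else 0.0 for m in Matrix]


def encode_nested(values: Any, shape: Tuple[int, ...]) -> List[float]:
    # Walk the nested lists dimension by dimension (row-major order)
    if not shape:
        return one_hot(values)  # reached a single element
    if len(values) != shape[0]:
        raise ValueError(f"expected {shape[0]} items, got {len(values)}")
    out: List[float] = []
    for item in values:
        out.extend(encode_nested(item, shape[1:]))
    return out


def flat_size(shape: Tuple[int, ...]) -> int:
    # Total length = product of the dimensions times the element width
    size = len(Matrix)
    for dim in shape:
        size *= dim
    return size


print(encode_nested([[Matrix.THE_MATRIX, Matrix.RELOADED]], (1, 2)))
# [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
```

Shape (1, 2) yields the same six numbers as the flat (2,) example above; the shape only tells the decoder how to regroup them into nested lists.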
If the List element type is Optional, the list is automatically padded with None values up to the configured shape.
>>> @dataclass_tensor
... @dataclass
... class MaybeMultipleWatchList:
...     matrices: List[Optional[Matrix]] = field(metadata=config(shape=(3,)))
...
>>> MaybeMultipleWatchList([Matrix.THE_MATRIX, Matrix.RELOADED]).to_numpy()
array([0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.])
>>> MaybeMultipleWatchList.from_numpy([0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.])
MaybeMultipleWatchList(matrices=[<Matrix.THE_MATRIX: 1>, <Matrix.RELOADED: 2>, None])
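Padding follows naturally once each element carries its own "None" flag. A minimal pure-Python sketch reproducing the vector above; encode_padded is a hypothetical helper, and rejecting over-long input is this sketch's own choice.

```python
from enum import Enum
from typing import List, Optional, Sequence


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_optional(value: Optional[Matrix]) -> List[float]:
    # "is None" flag followed by the one-hot layout (all zeros when value is None)
    return [1.0 if value is None else 0.0] + [1.0 if m is value else 0.0 for m in Matrix]


def encode_padded(values: Sequence[Optional[Matrix]], size: int) -> List[float]:
    if len(values) > size:
        raise ValueError(f"at most {size} elements fit, got {len(values)}")
    padded = list(values) + [None] * (size - len(values))  # right-pad with None
    out: List[float] = []
    for v in padded:
        out.extend(encode_optional(v))
    return out


print(encode_padded([Matrix.THE_MATRIX, Matrix.RELOADED], 3))
# [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]
```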
Union
typing.Union is encoded as a one-hot tag that indicates which option of the union is present, followed by the layouts of all options; the slots of the options that are not present are left as zeros.
>>> from typing import Union
>>>
>>> class Batman(Enum):
...     BEGINS = 1
...     DARK_KNIGHT = 2
...     DARK_KNIGHT_RISES = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     next_movie: Union[Matrix, Batman]
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([1., 0., 0., 1., 0., 0., 0., 0.])
>>> WatchList.from_numpy(_)
WatchList(next_movie=<Matrix.RELOADED: 2>)
>>> WatchList(Batman.DARK_KNIGHT).to_numpy()
array([0., 1., 0., 0., 0., 0., 1., 0.])
>>> WatchList.from_numpy(_)
WatchList(next_movie=<Batman.DARK_KNIGHT: 2>)
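A minimal pure-Python sketch of this tag-plus-slots layout, restricted to unions of Enums. encode_union and decode_union are hypothetical helpers, and argmax decoding is an assumption; the encoder selects the first option that passes isinstance.

```python
from enum import Enum
from typing import List, Sequence, Type


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


class Batman(Enum):
    BEGINS = 1
    DARK_KNIGHT = 2
    DARK_KNIGHT_RISES = 3


def encode_union(value: Enum, options: Sequence[Type[Enum]]) -> List[float]:
    # 1) one-hot tag: the first option that passes isinstance is selected
    tag = next((i for i, opt in enumerate(options) if isinstance(value, opt)), None)
    if tag is None:
        raise TypeError(f"{value!r} matches none of the union options")
    out = [1.0 if i == tag else 0.0 for i in range(len(options))]
    # 2) one slot per option, in order; only the selected option's slot is filled
    for i, opt in enumerate(options):
        out.extend(1.0 if (i == tag and m is value) else 0.0 for m in opt)
    return out


def decode_union(vec: Sequence[float], options: Sequence[Type[Enum]]) -> Enum:
    n = len(options)
    tag = max(range(n), key=lambda i: vec[i])               # which option is present
    offset = n + sum(len(opt) for opt in options[:tag])     # skip the earlier slots
    members = list(options[tag])
    chunk = vec[offset:offset + len(members)]
    return members[max(range(len(members)), key=lambda j: chunk[j])]


print(encode_union(Matrix.RELOADED, (Matrix, Batman)))
# [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```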
Decoding is fairly straightforward, but encoding can be problematic: Python's typing does not guarantee that the options of a union are mutually exclusive. The library tests the value against each type in the union with a simple isinstance check, and the first match wins. It does not traverse generics, origins, supertypes, etc., so define Union types with care.
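Because the first isinstance match wins, overlapping options can silently select the wrong slot. A small stdlib-only illustration of that rule (first_match is a hypothetical helper written for this example, not library code):

```python
from typing import Optional, Sequence


def first_match(value: object, options: Sequence[type]) -> Optional[type]:
    # Test each option with isinstance and keep the first hit
    for opt in options:
        if isinstance(value, opt):
            return opt
    return None


# bool is a subclass of int, so the order of the options decides the result
print(first_match(True, (int, bool)))   # <class 'int'>
print(first_match(True, (bool, int)))   # <class 'bool'>
print(first_match(1, (float, int)))     # <class 'int'> (an int is not a float)
```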
Recursive Definitions
Recursive definitions, such as linked lists, trees, or graphs, are not supported. For both usability and performance, the encoder/decoder must be able to determine the output tensor size statically.
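Why recursion is ruled out is easiest to see from a static size computation. Below is a minimal sketch, assuming the encoding rules listed in this section, that walks a type annotation and returns the flat tensor length. layout_size is hypothetical, and a plain "shape" metadata key stands in for config(shape=...), whose real metadata format is not documented here. A self-referencing dataclass needs an unresolved forward reference, which has no static size, so the sketch rejects it.

```python
import dataclasses
import enum
import typing
from typing import List, Optional, Union


def layout_size(tp, shape=None) -> int:
    """Statically compute the flat tensor length for a type annotation (sketch)."""
    origin, args = typing.get_origin(tp), typing.get_args(tp)
    if tp in (int, float, bool):
        return 1                                   # one scalar slot
    if isinstance(tp, type) and issubclass(tp, enum.Enum):
        return len(tp)                             # one-hot over members
    if origin is Union:
        options = tuple(a for a in args if a is not type(None))
        if len(options) < len(args):               # Optional[X]: "None" flag + X
            inner = options[0] if len(options) == 1 else Union[options]
            return 1 + layout_size(inner, shape)
        return len(options) + sum(layout_size(a) for a in options)  # tag + all slots
    if origin is list:
        if not shape:
            raise TypeError("lists need a static shape")
        return shape[0] * layout_size(args[0], shape[1:] or None)
    if isinstance(tp, type) and dataclasses.is_dataclass(tp):
        # fields are laid out one after another
        return sum(layout_size(f.type, f.metadata.get("shape"))
                   for f in dataclasses.fields(tp))
    # anything else, including the unresolved forward reference that a
    # self-referencing definition needs, has no static size
    raise TypeError(f"no static layout for {tp!r}")


class Player(enum.Enum):
    WHITE = 0
    BLACK = 1


class PieceType(enum.Enum):
    PAWN = 0
    BISHOP = 1
    KNIGHT = 2
    ROOK = 3
    QUEEN = 4
    KING = 5


@dataclasses.dataclass
class Piece:
    piece_type: PieceType
    owner: Player


@dataclasses.dataclass
class Chess:
    num_moves: float
    next_move: Player
    # hypothetical "shape" key standing in for config(shape=(64,))
    board: List[Optional[Piece]] = dataclasses.field(metadata={"shape": (64,)})


@dataclasses.dataclass
class Node:  # a linked list: recursive, so it has no static size
    value: int
    tail: Optional["Node"] = None


print(layout_size(Chess))  # 579
```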
Targets
The library supports the following containers as tensors:
- NumPy ndarray with to_numpy/from_numpy
- PyTorch tensors with to_torch/from_torch
The simplest way to work with TensorFlow tensors is to encode to a NumPy ndarray and convert the result with tensorflow.convert_to_tensor, which accepts ndarrays directly.
Note that the target libraries (NumPy, PyTorch, TensorFlow) are not installed along with this package and must be provided at runtime.
Performance
The tensor layout is not cached and is recomputed for every operation. When performing many operations on the same class, it makes sense to compute the layout once and reuse it. For example:
>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> layout = WatchList.tensor_layout()
>>> WatchList(Matrix.RELOADED).to_numpy(tensor_layout=layout)
array([0., 1., 0.])
>>> WatchList.from_numpy(_, tensor_layout=layout)
WatchList(matrix=<Matrix.RELOADED: 2>)
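The same reuse idea can also be expressed as memoization, so the layout is computed once per class without passing it around by hand. A stdlib-only sketch using functools.lru_cache; enum_layout and encode are hypothetical helpers illustrating the pattern, not library functions.

```python
import functools
from enum import Enum
from typing import Dict, List, Type


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


@functools.lru_cache(maxsize=None)
def enum_layout(enum_cls: Type[Enum]) -> Dict[Enum, int]:
    # Introspecting the class happens once; later calls hit the cache
    return {member: i for i, member in enumerate(enum_cls)}


def encode(value: Enum) -> List[float]:
    layout = enum_layout(type(value))  # served from the cache after the first call
    vec = [0.0] * len(layout)
    vec[layout[value]] = 1.0
    return vec


encode(Matrix.THE_MATRIX)
encode(Matrix.RELOADED)
print(enum_layout.cache_info())  # hits=1, misses=1
```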
Advanced Features
Dtype
The library supports float and integer (long) tensors. The data type can be specified either as a parameter of the dataclass_tensor decorator (applied to all operations) or as an argument to an individual conversion call such as to_numpy. See the examples below.
The dtype argument is passed through to the corresponding target library (NumPy or PyTorch).
>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([0., 1., 0.], dtype=float32)
>>> WatchList(Matrix.RELOADED).to_numpy(dtype="int32")
array([0, 1, 0], dtype=int32)
Or, with the default set in the decorator:
>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor(dtype="int32")
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList(Matrix.RELOADED).to_numpy()
array([0, 1, 0], dtype=int32)
Batch
To encode or decode a batch, pass batch=True. For example:
>>> class Matrix(Enum):
...     THE_MATRIX = 1
...     RELOADED = 2
...     REVOLUTIONS = 3
...
>>> @dataclass_tensor
... @dataclass
... class WatchList:
...     matrix: Matrix
...
>>> WatchList.to_numpy([
...     WatchList(Matrix.THE_MATRIX),
...     WatchList(Matrix.RELOADED),
... ], batch=True)
array([[1., 0., 0.],
       [0., 1., 0.]], dtype=float32)
>>> WatchList.from_numpy(_, batch=True)
[WatchList(matrix=<Matrix.THE_MATRIX: 1>),
 WatchList(matrix=<Matrix.RELOADED: 2>)]
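Conceptually, a batch is one encoded row per item, stacked into a 2-D result, and decoding treats each row independently. A minimal pure-Python sketch with nested lists standing in for the 2-D tensor (encode_batch and decode_batch are hypothetical helpers; argmax decoding is an assumption):

```python
from enum import Enum
from typing import List, Sequence


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_row(value: Matrix) -> List[float]:
    return [1.0 if m is value else 0.0 for m in Matrix]


def encode_batch(values: Sequence[Matrix]) -> List[List[float]]:
    # One row per item: the result has shape (len(values), len(Matrix))
    return [encode_row(v) for v in values]


def decode_batch(rows: Sequence[Sequence[float]]) -> List[Matrix]:
    members = list(Matrix)
    # Decode each row independently (argmax per row)
    return [members[max(range(len(row)), key=lambda j: row[j])] for row in rows]


batch = encode_batch([Matrix.THE_MATRIX, Matrix.RELOADED])
print(batch)                # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(decode_batch(batch))  # [<Matrix.THE_MATRIX: 1>, <Matrix.RELOADED: 2>]
```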
batch_size can be passed instead to provide a length hint (useful for performance when working with generators):
>>> WatchList.to_numpy((
...     WatchList(Matrix.THE_MATRIX),
...     WatchList(Matrix.RELOADED),
... ), batch_size=2)
array([[1., 0., 0.],
       [0., 1., 0.]], dtype=float32)
>>> WatchList.from_numpy(_, batch_size=2)
[WatchList(matrix=<Matrix.THE_MATRIX: 1>),
 WatchList(matrix=<Matrix.RELOADED: 2>)]
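Why a length hint helps with generators: without it, the input has to be materialized just to learn how many rows to allocate. A minimal pure-Python sketch of the idea, assuming (not confirmed by the library docs) that the hint is used to preallocate one contiguous buffer; encode_batch here is a hypothetical helper.

```python
from enum import Enum
from typing import Iterable, List, Optional


class Matrix(Enum):
    THE_MATRIX = 1
    RELOADED = 2
    REVOLUTIONS = 3


def encode_row(value: Matrix) -> List[float]:
    return [1.0 if m is value else 0.0 for m in Matrix]


def encode_batch(values: Iterable[Matrix], batch_size: Optional[int] = None) -> List[List[float]]:
    width = len(Matrix)
    if batch_size is None:
        values = list(values)  # no hint: materialize the input to learn its length
        batch_size = len(values)
    buf = [0.0] * (batch_size * width)  # one contiguous buffer, like a preallocated tensor
    count = 0
    for i, v in enumerate(values):
        if i >= batch_size:
            raise ValueError("more items than batch_size")
        buf[i * width:(i + 1) * width] = encode_row(v)  # write row i in place
        count = i + 1
    if count != batch_size:
        raise ValueError(f"expected {batch_size} items, got {count}")
    return [buf[i * width:(i + 1) * width] for i in range(batch_size)]


gen = (m for m in [Matrix.THE_MATRIX, Matrix.RELOADED])
print(encode_batch(gen, batch_size=2))  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
```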
Custom Attribute Resolver
TBD
TODO
- Custom attribute resolver (e.g. from dict instead of class instance)
- Pretty-print for tensor layout object
Contributing
- Check for open issues or open a fresh issue to start a discussion around a feature idea or a bug.
- Fork the repository on GitHub and branch from main to feature-* to start making your changes.
- Write a test which shows that the bug was fixed or that the feature works as expected.
or simply...
- Use it.
- Enjoy it.
- Spread the word.
License
Copyright © 2021, Oleksii Kachaiev.
dataclasses-tensor is licensed under the MIT license, available in the LICENSE file.
File details
Details for the file dataclasses-tensor-0.2.6.macosx-10.15-x86_64.tar.gz.
File metadata
- Download URL: dataclasses-tensor-0.2.6.macosx-10.15-x86_64.tar.gz
- Upload date:
- Size: 15.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | da54d4d4d7de6e8ed1ac8f78e31ba2cbe43c36bf8409c49fc136774e5383ca66
MD5 | 9a64c6888ac5b26127393bacd192c5be
BLAKE2b-256 | 73eab1836203072bbb0bd760c05d28685bcfe2b695fa0a400c3ae54a24229fcd
File details
Details for the file dataclasses_tensor-0.2.6-py3-none-any.whl.
File metadata
- Download URL: dataclasses_tensor-0.2.6-py3-none-any.whl
- Upload date:
- Size: 10.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | e840b5f35c7387c24226c133621fa5d30a3f8293c5a2d0a1fb8c01238f012cf3
MD5 | 0daa9034f63b48519ba8cdb62cc3bd2c
BLAKE2b-256 | 8b9f4ff4bcb693f922a10038955b00f1f189f1eea165e3e2c80f792271547834