
+----+
|asta|
+----+

Shape annotations for homogeneous numpy arrays and pytorch/tensorflow tensors.

Introduction
------------
This library defines subscriptable classes ``Array``, ``Tensor``, and
``TFTensor`` for use in ``isinstance()`` checks and type annotations. It also
adds a decorator, ``@typechecked``, which implements toggleable static type
enforcement for the classes described above.
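
For example, a minimal check (illustrative; exact output assumes the dtype
defaults described in the Basics section below):

>>> import numpy as np
>>> from asta import Array
>>> isinstance(np.zeros((4, 4), dtype=np.int64), Array[int, 4, 4])
True
>>> isinstance(np.zeros((4, 5), dtype=np.int64), Array[int, 4, 4])
False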


Installation
------------
$ pip install asta
$ pip install git+git://github.com/brendanxwhitaker/asta.git


Basics
------
Asta supports checking dtypes and/or shapes with wildcards and ellipses:

>>> Array[float] # dtype=np.float64.
>>> Array[float, 1, 2, 3] # dtype=np.float64, shape=(1,2,3).
>>> Array[float, (1, 2, 3)] # dtype=np.float64, shape=(1,2,3).
>>> Array[1, 2, 3] # shape=(1,2,3).
>>> Array[int, ()] # dtype=np.int64, shape=().
>>> Array[str, 1, 2, ..., 3] # dtype=unicode, shape=(1,2,...,3).
>>> Array[str, 1, 2, -1, 3] # dtype=unicode, shape=(1,2,*,3).
>>> Tensor[torch.uint8, 1] # dtype=torch.uint8, shape=(1,).
>>> TFTensor[tf.complex128, ()] # dtype=tf.complex128, shape=().


Wildcard dimensions (-1)
------------------------
A ``-1`` can be used as a wildcard in place of a single dimension size. It
behaves as a stand-in for any positive integer, and it can be used as many
times as desired within a shape hint.
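
For example, the following checks should hold (illustrative):

>>> import numpy as np
>>> from asta import Array
>>> isinstance(np.zeros((5, 3)), Array[-1, 3])
True
>>> isinstance(np.zeros((5, 3, 2)), Array[-1, -1, -1])
True
>>> isinstance(np.zeros((5,)), Array[-1, 3])
False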


Ellipses (...)
--------------
An ``Ellipsis`` object (``...``) can be used in place of any number of
positive integer dimension sizes, including none at all. The intended use case
is when you know a prefix of a tensor shape, the batch size for example, but
not the remaining dimension sizes (perhaps they vary). You can use something
like:

>>> Array[32, ...]

to match arrays with shape:

>>> (32,)
>>> (32, 24)
>>> (32, 1, 2, 3)

and so on. These should be a LAST RESORT since, as will be discussed below, we
can specify dimensions, shapes, or portions of shapes as variables, even if
they will change at runtime, e.g. on every iteration of some loop.
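
For completeness, illustrative checks of the ``Array[32, ...]`` hint above:

>>> import numpy as np
>>> from asta import Array
>>> isinstance(np.zeros((32, 1, 2, 3)), Array[32, ...])
True
>>> isinstance(np.zeros((16, 24)), Array[32, ...])
False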

Multiple ellipses can be used within a single annotation as well, in case you
happen to know that a tensor or array will have a dimension size of 7 somewhere
in its shape, but you don't know where, or you don't know the total rank.
Something like the following:

>>> TFTensor[..., 7, ...]

will match tensorflow tensors with shape:

>>> (7,)
>>> (1, 7)
>>> (7, 3)
>>> (1, 2, 3, 7, 4, 5, 6)

and so on. Note that two ellipses cannot be used consecutively, e.g.

>>> Array[1, 2, ..., ..., 5]

since asta won't be able to determine which dimensions should be substituted
for which ellipsis object.
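
The same multi-ellipsis syntax applies to ``Array`` and ``Tensor`` as well;
for instance (illustrative):

>>> import numpy as np
>>> from asta import Array
>>> isinstance(np.zeros((1, 2, 3, 7, 4)), Array[..., 7, ...])
True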


Typechecked decorator
---------------------
As mentioned above, asta implements a decorator for runtime typechecking using
type hints like ``Array[]``. It can be used to decorate functions, instance
methods, class methods, metaclass methods, or even entire classes (this is just
equivalent to decorating each of the class's methods).

The following gives an example of using the ``@typechecked`` decorator to
enforce torch tensor shapes and dtypes at runtime. The function ``kl`` will
raise a TypeError if called with inputs which have any dtype other than
``torch.float32``, or any shape other than ``(8, 64)``.

>>> import os
>>> import torch.nn.functional as F
>>> from asta import Tensor, typechecked
>>>
>>> os.environ["ASTA_TYPECHECK"] = "1"
>>>
>>>
>>> @typechecked
>>> def kl(t_1: Tensor[float, 8, 64], t_2: Tensor[float, 8, 64]) -> Tensor[()]:
>>>     """ Computes the KL divergence of two Tensors of shape ``(8, 64)``. """
>>>     divergence = F.kl_div(t_1, t_2, reduction="sum")
>>>     return divergence
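
For contrast, a call that violates the hint fails at call time (illustrative
sketch; assumes the environment variable above is set):

>>> import torch
>>> t_1 = torch.rand((8, 64))
>>> t_2 = torch.rand((8, 32))  # Wrong second dimension.
>>> kl(t_1, t_2)  # Raises TypeError, since (8, 32) != (8, 64).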

A runnable example is given in ``example.py`` in the repository root. For a
more fleshed-out, real world example, see ``asta/tests/rl/`` for a typechecked
policy gradient implementation.


Variable shapes and dimensions
------------------------------
Asta also supports variable shape arguments:

>>> from asta import dims, shapes
>>> BATCH_SIZE = dims.BATCH_SIZE
>>> SEQ_LEN = dims.SEQ_LEN
>>> OBS_SHAPE = shapes.OBS_SHAPE
>>> ...
>>> dims.BATCH_SIZE = 32
>>> dims.SEQ_LEN = 256
>>> shapes.OBS_SHAPE = (4, 5, 6)
>>> ...
>>> Tensor[float, BATCH_SIZE, SEQ_LEN] # shape=(32,256).
>>> Tensor[float, (BATCH_SIZE,) + OBS_SHAPE] # shape=(32,4,5,6).

These can be accessed from ``asta.dims`` for use in annotating top-level
functions and then set after import-time during the execution of a program,
allowing for shape annotations which depend on control flow or which change
during execution:

(utils.py)
>>> from asta import dims
>>> HIDDEN_DIM = dims.HIDDEN_DIM
>>> @typechecked
>>> def identity(x: Tensor[float, HIDDEN_DIM]) -> Tensor[float, HIDDEN_DIM]:
>>>     return x
...
(main.py)
>>> import torch
>>> from asta import dims
>>> from utils import identity
>>> dims.HIDDEN_DIM = 768
>>> x = torch.ones((768,))
>>> y = identity(x)

As seen above, this allows dimension and shape constants, stored in
``asta.dims`` and ``asta.shapes`` respectively by setting a named attribute, to
persist as globals across modules. These need not be constants, either: we can
reset the value of ``dims.<NAME>``, and all typechecked function calls
occurring after that statement executes will reflect the updated value of
``dims.<NAME>``. Under the hood, this works because ``dims`` and ``shapes`` are
actually instances of classes with the ``__getattr__`` and ``__setattr__``
functions overridden. They return a placeholder object for any attribute
accessed, and then update the value of that placeholder whenever that attribute
is set later on.
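
A minimal sketch of this mechanism (illustrative only, not asta's actual
source; ``Placeholder`` and ``DimStore`` are hypothetical names):

>>> class Placeholder:
>>>     """A named stand-in whose value can be filled in later."""
>>>     def __init__(self, name):
>>>         self.name = name
>>>         self.value = None
>>>
>>> class DimStore:
>>>     """Returns a Placeholder for any attribute accessed; setting the
>>>     attribute updates that same Placeholder, so references captured
>>>     earlier (even in other modules) see the new value."""
>>>     def __init__(self):
>>>         object.__setattr__(self, "_placeholders", {})
>>>     def __getattr__(self, name):
>>>         # Only called when normal attribute lookup fails.
>>>         return self._placeholders.setdefault(name, Placeholder(name))
>>>     def __setattr__(self, name, value):
>>>         self.__getattr__(name).value = value
>>>
>>> store = DimStore()
>>> HIDDEN = store.HIDDEN  # A Placeholder; its value is still None.
>>> store.HIDDEN = 768     # Fills in the value of the same Placeholder.
>>> HIDDEN.value
768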


Dimension inference
-------------------
Asta also supports arbitrary expressions of symbolic shape arguments:

>>> from asta import symbols
>>> X = symbols.X
>>> Y = symbols.Y
>>> Tensor[float, (X, X**2, X**3)] # e.g. shape=(2,4,8) or shape=(1,1,1).
>>> Tensor[float, (X + Y, (X + Y)**2)] # Multivariate supported as well.

Checks on the above objects will do inference based on the passed tensor,
constructing a system of equations from the actual shape and returning a
match if there exists at least one solution to the system. This means the
equation solver is lenient: if you pass it a system which has many solutions,
or infinitely many, it will still return a match. Something like:

>>> from asta import symbols
>>> X = symbols.X
>>> Y = symbols.Y
>>> Tensor[X + Y]

used as an annotation for ``torch.ones((50,))``, for example, yields the
equation ``X + Y = 50``, which has many solutions for positive integer values
of ``X`` and ``Y``. But this will not raise an error. As with ellipses,
symbols should only be used as a LAST RESORT, since it is often possible to
lazily define a ``dims`` placeholder prior to the function call instead, which
is more precise. They are useful, however, if you're averse to peppering your
code with dimension storage statements (like ``dims.W = 256``), or if it
doesn't matter what the specific values of the sizes are, only that they obey
some constraint relative to each other.
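
Conceptually, the check amounts to solving such a system directly; a sketch
using ``sympy`` (illustrative only, not necessarily asta's internals):

>>> import sympy
>>> X = sympy.Symbol("X", integer=True, positive=True)
>>> symbolic_shape = (X, X**2, X**3)
>>> actual_shape = (2, 4, 8)
>>> system = [sympy.Eq(s, d) for s, d in zip(symbolic_shape, actual_shape)]
>>> sympy.solve(system, [X], dict=True)  # Nonempty => the shapes match.
[{X: 2}]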


Configuration
-------------
The behavior of the ``@typechecked`` decorator can be configured through a
configuration file, like ``setup.cfg``, ``pyproject.toml``, ``astarc``, or
``.astarc``. The ``on`` option can also be toggled via an environment variable
called ``ASTA_TYPECHECK``: typechecking runs when it is set to "1" and is
skipped when it is set to "0". If the environment variable and the
configuration file conflict, asta defaults to ``off``.
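
Put differently, checking runs only when both switches agree that it should; a
minimal sketch of that resolution rule (illustrative, not asta's source):

>>> def typecheck_enabled(config_on: bool, env_value: str) -> bool:
>>>     """Any disagreement between config and environment means off."""
>>>     return config_on and env_value == "1"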

An example config file is given below, and in ``asta/defaults/astarc``.

>>> [MASTER]
>>>
>>> on=yes
>>> raise-errors=yes
>>> print-passes=yes
>>> check-non-asta-types=no
>>> check-all-sequence-elements=yes

And explanations of the options:

``on`` : Determines whether or not functions are typechecked. Does as little as
possible other than calling the wrapped function when set to ``no``.
``raise-errors`` : If ``yes``, errors will be raised when a typecheck fails;
otherwise, the error will only be printed.
``print-passes`` : If ``yes``, all passed typechecks will be printed;
otherwise, they will be silent (similar to mypy/pylint output).
``check-non-asta-types`` : If ``yes``, asta will attempt to check the types of
everything passed to or returned from a function. A wide variety of types
are supported. Note: some of the helper functions for these have been
adapted from those in the package ``typeguard``. Otherwise, it will only
check ``Array``, ``Tensor``, and ``TFTensor``, and composites of these,
like ``Dict[str, Tensor[1,2,3]]``.
``check-all-sequence-elements`` : If ``yes``, it will check the types of all
elements in iterable types like ``List[*]``. Otherwise, it will only check the
first element in an attempt to be faster.


Subscript arguments
-------------------
The valid subscript arguments for ``Array``, ``Tensor``, and ``TFTensor`` are
as follows:

Types
-----

Array
-----
1. Any python type from the following list:
   a. int
   b. float
   c. bool
   d. complex
   e. bytes
   f. str
   g. datetime
   h. timedelta
2. Any numpy dtype, e.g. ``np.int64``.
3. Omitted (no argument passed).

Tensor
------
1. Any python type from the following list:
   a. int
   b. float
   c. bool
   d. bytes
2. Any ``torch.Tensor``-supported torch dtype, e.g. ``torch.int64``.
3. Omitted (no argument passed).

TFTensor
--------
1. Any python type from the following list:
   a. int
   b. float
   c. bool
   d. bytes
2. Any ``tf.Tensor``-supported tensorflow dtype, e.g. ``tf.int64``.
3. Omitted (no argument passed).

Shapes
------
1. Nonnegative integers.
2. ``-1``, a wildcard for any positive integer size.
3. Ellipses (``...``), a placeholder for any contiguous sequence of
   positive integer sizes, including zero-length sequences.
4. ``()`` or ``Scalar``, which both indicate a scalar array or tensor.
   These are interchangeable.
5. Dimensions from ``asta.dims``.
6. Shapes from ``asta.shapes``.
7. Symbols from ``asta.symbols``.
8. Omitted (no argument passed).
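
Several of these can be combined in a single hint (illustrative):

>>> import numpy as np
>>> from asta import Array
>>> isinstance(np.zeros((3, 7, 2, 2), dtype=np.int64), Array[np.int64, 3, -1, ...])
True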


Shape constraints and best practices
------------------------------------
There is a key difference between the way scalar values are handled in numpy
and the way they are handled in torch. Consider an array/tensor of shape
``(2,2,2)``. When indexing the first element along each dimension, which should
be scalar, we call ``a[0,0,0]``, where ``a`` is our array. We do the same to a
tensor ``t``, and assign the results to variables ``x`` and ``y``, respectively:

>>> a = np.zeros((2,2,2))
>>> t = torch.zeros((2,2,2))
>>> x = a[0,0,0]
>>> y = t[0,0,0]

What are the types of ``x`` and ``y``?

>>> type(x)
<class 'numpy.float64'>
>>> type(y)
<class 'torch.Tensor'>

Interestingly, although ``a`` is of type ``np.ndarray``, ``x`` is of type
``np.float64``, a subclass of ``float``; both ``t`` and ``y``, in contrast, are
tensors. Note that ``x`` is not an array:

>>> isinstance(x, np.ndarray)
False
>>> isinstance(x, float)
True

And ``y`` is not a float:

>>> isinstance(y, torch.Tensor)
True
>>> isinstance(y, float)
False

Asta does not attempt to rectify this discrepancy, and so the behavior of
``Array`` and ``Tensor`` when it comes to scalars is slightly different. In the
above, ``x`` is not an instance of ``Array``, while ``y`` is an instance of
``Tensor``, even though they are indexed in an identical manner.
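
Concretely, continuing the example above (illustrative checks):

>>> from asta import Array, Tensor
>>> isinstance(x, Array[()])
False
>>> isinstance(y, Tensor[()])
True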

Use of the ellipsis placeholder (``...``) in asta is meant to mirror its usage
in numpy/torch indexing. This is why we allow ``...`` to take the place of an
empty shape. Note that indexing scalar arrays with an ellipsis returns the
array unchanged:

>>> np.zeros(())
array(0.)
>>> a = np.zeros(())
>>> a[...]
array(0.)

And adding an Ellipsis anywhere in an already-filled index tuple will return a
scalar array with the expected value:

>>> a = np.zeros((2,2,2))
>>> a[0,0,0]
0.0
>>> a[0,0,0,...]
array(0.)
>>> a[0,0,...,0]
array(0.)
>>> a[...,0,0,0]
array(0.)

In contrast, we take the ``-1`` wildcard argument to represent exactly one
positive integer shape element. So if you wanted to match all arrays with
shape ``(1,*,1)``, where ``*`` is nonempty, i.e. don't match arrays of shape
``(1,1)``, but do match any of the following:

1. ``(1, 1, 1)``
2. ``(1, 2, 1)``
3. ``(1, 2, 3, 4, 5, 1)``

You would use ``Array[1,-1,...,1]``.
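
That is (illustrative):

>>> import numpy as np
>>> from asta import Array
>>> isinstance(np.zeros((1, 1)), Array[1, -1, ..., 1])
False
>>> isinstance(np.zeros((1, 2, 3, 4, 5, 1)), Array[1, -1, ..., 1])
True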


Performance
-----------
The runtime checking functionality of asta is NOT meant to be used in
situations where performance/speed is critical. Furthermore, use of the values
of type hints within python code, which the ``@typechecked`` decorator relies
on, is generally not recommended; the ability to type hint in python is meant
to be just that, a hint. The decorator is useful as a debugging or testing
step when working on large, complicated models or workflows. Many of the
native numpy and pytorch array/tensor functions allow arbitrarily shaped
inputs, and it is easy for a malformed shape to pass unnoticed, with no effects
other than poor downstream performance or results. Asta is meant to be a crude
aid in dealing with this common problem, but by no means a comprehensive one.

This having been said, the ``isinstance()`` checks used are relatively cheap,
and shouldn't cause a serious slowdown outside of exceptional cases.

The recommended usage of this library would be to annotate all critical
functions which take or return ndarrays/tensors, and decorate them with
``@typechecked``. One could then add a CI test which sets the
``ASTA_TYPECHECK`` environment variable to ``1`` and runs a sample workflow.
Any incorrect dtypes or malformed shapes will raise a TypeError, and
typechecks which pass will print to stdout. This behavior is intentional:
leaving the environment variable set causes a slight slowdown, and the printed
output makes that otherwise-silent slowdown visible.
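
A sketch of such a CI hook (hypothetical names: ``my_project`` and
``run_sample_workflow`` are stand-ins for your own code):

>>> import os
>>> os.environ["ASTA_TYPECHECK"] = "1"  # Enable checking for this run.
>>> from my_project import run_sample_workflow  # Hypothetical entry point.
>>> run_sample_workflow()  # Any malformed shape/dtype raises TypeError here.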


========
INTERNAL
========


Todo
----
- Add ``# type: ignore`` comments in test files. The ``[type-arg]`` and
ellipses errors will be ignored when the package is installed. They just need
to be silenced within the package itself. (DONE)
- Delete ``demo.py``. (DONE)
- Implement ``-1`` wildcard shape element support. (DONE)
- Add tests for ``Tensor``. (DONE)
- Write examples in README. (DONE)
- Add tests for empty arrays and tensors. (DONE)
- Consider making Ellipses only match empty shapes and positive integer shapes,
but not zero-valued shapes. This would be done under the assumption that most
people are not interested in working with empty arrays/tensors. And if they
are, they can use the ``.size`` attribute for an easy check. (DONE)
- Add reprs. (DONE)
- Fix tensor strategy. (DONE)
- Add an option to disable typechecked decorator (default=disabled, ``.astarc``
file).
- Add environment variable for typechecking. (DONE)
- Add tests for ``@typechecked``. (DONE)
- Consider changing name of decorator to ``@shapechecked``. (NO)
- Consider dropping the ``Scalar`` object. The less unfamiliar objects, the
better. (NO)
- Add more descriptive error if you pass torch dtype for an Array or numpy
dtype for a Tensor. (DONE)
- Add ``CudaTensor`` class. (DONE)
- Consider adding support for arbitrary shape/type constraints as subscript
arguments. They would be of the form:

>>> from typing import Tuple
>>>
>>> def fn(shape: Tuple[int, ...]) -> bool:
>>>     n: int
>>>     for i, elem in enumerate(shape):
>>>         if i == 0:
>>>             n = elem
>>>         elif elem == 2 * n:
>>>             n = elem
>>>         else:
>>>             return False
>>>     return True

The above constraint would only pass for shapes ``(n, 2n, 4n, 8n...)``.
To enforce it, you would use ``Array[fn]``. (NO)
- Consider allowing instantiation to support kwargs. (DETAILED BELOW)
- Consider adding a delimiter in the typecheck decorator output to distinguish
the typecheck passes between different functions and signatures. Something
like ``===============<asta.typechecked_function()>===============``. (DONE)
- Consider adding support for tensorflow. (DONE)
- Consider adding support for passing ``()`` as in ``Array[int, ()]`` to denote
scalar arrays instead of ``Scalar`` or ``None``. (DONE)
- Write tests for ``Scalar``.
- Write analogues of new ``Array`` tests for ``Tensor``. (DONE)
- Consider reserving ``None`` shape and type for uninitialized shape globals.
(NO)
- Attempting to typecheck with uninitialized globals will raise an error
warning that they are uninitialized. Then you can set your dim variable
defaults in a config module all to ``None``. (DONE)
- Add uninitialized dimension size error to decorator. See preceding note.
(DONE)
- Add section in README about using ``asta.dims`` for programmatically-set
dimension sizes in shape annotations.
- Fix base class of ``_ArrayMeta`` and ``_TensorMeta`` so that type ignore is
not needed in ``decorators.py``.
- Consider removing library-specific metaclasses.
- Consider making ``parse_subscript`` a classmethod of ``SubscriptableMeta``.
(NO)
- Remove torch, tensorflow from requirements. You should be able to use
``Array`` without torch installed. (DONE)
- When you try to import Tensor when torch is not installed, no error should be
raised. But when you then try to use/subscript Tensor, an error should be
raised. So it should still be a subscriptable type, but the subscription
method should try to import torch. (DONE)
- Consider adding a ``soft`` mode where the decorator only prints errors
instead of raising them, so that a user could see and correct all of them at
once. (DONE)
- Consider adding support for checking arbitrary attributes on arrays and
tensors. For example, if working on a reinforcement learning problem, and you
wanted to make sure all tensors have the same timestep index, you could set a
``<torch.Tensor>.timestep: int`` attribute and then assert that they all
match at function call time. (DONE)
- This could have syntax like:

>>> def product(
>>>     t1: Tensor(float, 1,2,3, timestep=K),
>>>     t2: Tensor(float, 1,2,3, timestep=K + 1),
>>> ) -> Tensor(float, 1,2,3):

This would require implementing an ``__init__()`` for array classes, which
would just return an instance of the metaclass with the appropriate class
variables set. Alternatively, we could use a dictionary:

>>> def product(
>>>     t1: Tensor[float, 1,2,3, {"timestep": K}],
>>>     t2: Tensor[float, 1,2,3, {"timestep": K + 1}],
>>> ) -> Tensor[float, 1,2,3]:

This is much better because it doesn't break the convention of using
subscriptable objects for annotations. (DONE)

- This would also require an implementation of non-constant shapes/variables.
For example, if you know all your tensors will have length ``k``, but ``k``
is variable, then you import ``K`` from some asta namespace. This will return
a placeholder-like object with no set value. At typecheck time, it will be
inferred from the first argument checked which makes use of this variable.
The typechecker will assert that all other instances of ``K`` have the same
value as this initialized one, and if soft checking is enabled, it will print
out all of them along with an error message. (DONE)

- Rewrite README.
- Fix source code headers. (DONE)
- Add support for generics:
  - List (DONE)
  - Sequence (DONE)
  - Dict (DONE)
  - Set (DONE)
  - Callable (DONE)
- Write tests for decorated classes, attributes, equation solver, placeholders,
instance/class/metaclass methods, and the above generic subscriptable types.
- Resolve the discrepancy between default dtypes of tensors (float32 in TF),
and default dtype given the primitive float argument (float64). What is the
most natural behavior? Follow PLA. (DONE)
- In line with above, consider making checks for primitive types (int, float,
complex, bytes, str) less strict. This should be done if there is time. (NO)
- Put on PyPI.
- Bug: disabling/enabling asta per-file is likely broken, since importing check
sets the value of an environment variable, which is global during python
execution. (DONE)
- Consider adding ``asta.shp`` and ``asta.dims`` so that integer and tuple
dimensions/shapes are in two separate storage modules. This would also mean
you could write ``shp.OBS`` instead of ``OBS_SHAPE`` or ``dims.OBS_SHAPE``,
which are both rather long and take up a good chunk of the hint size.
- Decide on whether the storage modules should be singular (shp, dim, symbol)
or plural (shps, dims, symbols). Singular is shorter and looks nicer, maybe?
- Placeholder should support addition as tuple concatenation. Does this already
work?
- Placeholders dims (scalars) should support arbitrary mathematical expressions
involving symbols.
- Add configuration parsing. (DONE)
- Consider just not implementing per-module typecheck enable/disable. It should
be all or nothing via either the configuration file option or the environment
variable.
- Consider implementing per-module disable via a global variable in each
module. The decorators can see the scope of the module from which they are
called, and these scopes are independent, so set ``ASTA_MODULE_DISABLE =
True`` and asta will skip checking that module.
- Don't support default precisions because it decreases type transparency.
(DONE)
- Support subscriptable aliases:

>>> CudaTensor = Tensor[{"device": "cuda:0"}]
>>> FloatCudaTensor = CudaTensor[float]

Currently, subscripting ``CudaTensor`` as above will overwrite/delete the
existing subscript. So just add a class attribute that saves the subscript,
and upon subscription, combines the previous subscript with the new
subscript, raising an error if an attribute is already set. Actually, on
second thought, instead of saving the subscript, just save the parsed class
attributes. (DONE)
- Bug: tuples cannot be summed without unpacking ``Tensor[shapes.OB + (1,)]``.
(DONE)
- Bug: Header should still be printed when an error is raised if
``ox.print_passes`` is off. (DONE)
- ---MERGE---(DONE)
- Update README (DONE)
- Add a warning when ``ox.print_passes`` is False, so that the user doesn't
accidentally leave typechecking enabled, slowing down their program; make this
happen at import-time, so that even if the typechecked functions aren't called
for a while, the user still knows immediately. (DONE)
- ---MERGE---
- Consider adding an ``args`` module. So if you have a function:

>>> def fn(k: int) -> Tensor[<k>]:
>>>     return torch.ones((k,))

where the shape is dependent on the arguments, you could annotate it.
Syntax would look like this:

>>> def fn(k: int) -> Tensor[args.k]:
>>>     return torch.ones((k,))

- Consider adding probabilistic checking, i.e. if a function has been called
100,000 times and none of them have failed, maybe stop checking it,
especially if it's taking a long time. We can do inference and estimate the
probability that the function will fail.
- Before we do the above, attempt to jit compile everything in asta.
- Consider making overwrites of class attributes like:

>>> Tensor[float][float]

raise an error saying the attribute has already been defined. Perhaps
``kwattrs`` allows overwriting as long as there are no keys in common or
something. Is an error expected behavior, or should we be lenient?
- Consider adding checks before looping over any iterable to make sure the
subscript annotation is an asta type in the case where
``ox.check_non_asta_types`` is false.
- Consider making ``@typechecked`` callable so we can call:

>>> @typechecked(print=True)
>>> def ...

So then we can have ``print-all-passes=no`` but still print the one we're
debugging.



Acknowledgements
----------------
- Based on the excellent packages 'nptyping' by Ramon Hagenaars and 'typeguard'
by Alex Grönholm.
- Thanks to Michael Crawshaw (@mtcrawshaw) for helpful comments on natural
shape constraints and handling of 0-valued dimension sizes.

