Pydantic support for Jaxtyping array annotations
Py Jaxtyping
Usage
Instead of Int[np.ndarray, 'B 256 64 3'], do PyArray[Int, int, 'B 256 64 3'].
You can use it with jaxtyping as normal, but it will also:
- Serialize to nested lists
- Validate the correct shape and datatypes from serialized lists
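To make the two points above concrete, here is a minimal pure-Python sketch of what checking a serialized nested list involves: recovering its shape and confirming the element type. This is not py_jaxtyping's implementation; the helper names nested_shape and all_of_type are hypothetical and exist only for illustration.

```python
from typing import Any


def nested_shape(data: Any) -> tuple[int, ...]:
    """Return the shape of a rectangular nested list; raise ValueError if ragged."""
    if not isinstance(data, list):
        return ()  # a scalar leaf has an empty shape
    if not data:
        return (0,)
    # Every child must report the same shape, otherwise the list is ragged.
    child_shapes = {nested_shape(item) for item in data}
    if len(child_shapes) != 1:
        raise ValueError("ragged nested list: children have different shapes")
    return (len(data), *child_shapes.pop())


def all_of_type(data: Any, leaf_type: type) -> bool:
    """Check that every leaf of a nested list is an instance of leaf_type."""
    if isinstance(data, list):
        return all(all_of_type(item, leaf_type) for item in data)
    # bool is a subclass of int in Python, so reject it explicitly for int leaves.
    if leaf_type is int and isinstance(data, bool):
        return False
    return isinstance(data, leaf_type)


# Hypothetical serialized array with shape 1 x 2 x 3.
img = [[[1, 2, 3], [4, 5, 6]]]
print(nested_shape(img))
print(all_of_type(img, int))
```

A real validator would then compare the recovered shape against the annotation and convert the list back into an array; the sketch stops at the two checks themselves.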
With pydantic, make sure to use model_config = ConfigDict(arbitrary_types_allowed=True).
Example
from pydantic import BaseModel, ConfigDict
from py_jaxtyping import PyArray
from jaxtyping import Int
import numpy as np

class Sample(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    img: PyArray[Int, int, "W H 3"]
    label: str

Sample.model_validate({
    'img': np.ones((256, 64, 3)),
    'label': 'car'
})
# checks out!

Sample.model_validate({
    'img': np.ones((256, 64, 1)),
    'label': 'car'
})
# fails: invalid dims :/
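
Why the first call passes and the second fails can be sketched in plain Python. The function below is a hypothetical illustration, not the library's code, and it handles only a small subset of the jaxtyping shape syntax: integer tokens that must match the axis size exactly, and named tokens that accept any size but must stay consistent when repeated.

```python
def matches_spec(shape: tuple[int, ...], spec: str) -> bool:
    """Check a concrete shape against a space-separated spec such as "W H 3"."""
    tokens = spec.split()
    if len(tokens) != len(shape):
        return False  # wrong number of axes
    bound: dict[str, int] = {}  # sizes already bound to named axes
    for token, size in zip(tokens, shape):
        if token.isdigit():
            # A fixed-size axis must match exactly.
            if int(token) != size:
                return False
        elif bound.setdefault(token, size) != size:
            # A named axis that repeats must bind to the same size each time.
            return False
    return True


print(matches_spec((256, 64, 3), "W H 3"))  # last axis is 3, as required
print(matches_spec((256, 64, 1), "W H 3"))  # last axis is 1, so the check fails
```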