
Pydantic support for Jaxtyping array annotations


Py Jaxtyping


Usage

Instead of annotating a field as Int[np.ndarray, 'B 256 64 3'], use PyArray[Int, int, 'B 256 64 3'].

It works with jaxtyping as usual, and in addition it will:

  • Serialize arrays to nested lists
  • Validate shape and dtype when loading arrays back from serialized lists
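To make the second point concrete, here is a small, dependency-free sketch of checking serialized nested lists against a dimension spec such as "W H 3". It is not py_jaxtyping's implementation: the helper names infer_shape, iter_leaves and validate_nested are hypothetical, and the sketch assumes a simplified spec in which digit tokens are fixed sizes and any other token is a named size that must agree everywhere it appears. jaxtyping's real syntax is richer (for example "..." and "*batch" for variadic dimensions).

```python
from typing import Any, Iterator


def infer_shape(data: Any) -> tuple[int, ...]:
    """Shape of a rectangular nested list; a scalar has shape ()."""
    if not isinstance(data, list):
        return ()
    if not data:
        return (0,)
    sub_shapes = [infer_shape(item) for item in data]
    if any(s != sub_shapes[0] for s in sub_shapes):
        raise ValueError("ragged nested list: sub-lists differ in shape")
    return (len(data),) + sub_shapes[0]


def iter_leaves(data: Any) -> Iterator[Any]:
    """Yield every non-list element of a nested list, depth first."""
    if isinstance(data, list):
        for item in data:
            yield from iter_leaves(item)
    else:
        yield data


def validate_nested(data: Any, scalar_type: type, spec: str) -> tuple[int, ...]:
    """Check nested lists against a spec like "W H 3" and return the shape.

    Hypothetical simplified rules: digit tokens are fixed sizes, other tokens
    are named sizes bound on first use and checked on every later use.
    """
    shape = infer_shape(data)
    dims = spec.split()
    if len(dims) != len(shape):
        raise ValueError(f"expected {len(dims)} dims ({spec!r}), got shape {shape}")
    bound: dict[str, int] = {}
    for dim, size in zip(dims, shape):
        if dim.isdigit():
            if int(dim) != size:
                raise ValueError(f"fixed dim {dim} does not match size {size}")
        elif bound.setdefault(dim, size) != size:
            raise ValueError(f"dim {dim!r} already bound to {bound[dim]}, got {size}")
    for leaf in iter_leaves(data):
        # Exact type match, so True/False are rejected when scalar_type is int.
        if type(leaf) is not scalar_type:
            raise TypeError(f"expected {scalar_type.__name__}, got {type(leaf).__name__}")
    return shape


print(validate_nested([[[0, 0, 0], [0, 0, 0]]], int, "W H 3"))  # (1, 2, 3)
```

Binding names per call is what lets a spec like 'B 256 64 3' accept any batch size B while keeping the other sizes fixed.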

Example

from pydantic import BaseModel, ConfigDict
from py_jaxtyping import PyArray
from jaxtyping import Int
import numpy as np

class Sample(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    img: PyArray[Int, int, "W H 3"]
    label: str

Sample.model_validate({
    'img': np.ones((256, 64, 3), dtype=int),  # integer data, to match the Int annotation
    'label': 'car'
})
# checks out!


Sample.model_validate({
    'img': np.ones((256, 64, 1), dtype=int),
    'label': 'car'
})
# fails: the last dimension is 1, but "W H 3" requires 3
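For comparison, the serialize-to-nested-lists behaviour can be approximated with plain pydantic v2 and NumPy, without py_jaxtyping. This is a hedged sketch, not the library's mechanism: IntImage and to_int_array are hypothetical names, the integer-only dtype policy is an assumption, and the shape rule is hard-coded to mirror the "W H 3" annotation in the example. It assumes pydantic 2.x, where BeforeValidator, PlainSerializer and ValidationError are importable from the top-level pydantic package.

```python
import json
from typing import Annotated

import numpy as np
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, ValidationError


def to_int_array(value: object) -> np.ndarray:
    """Coerce an ndarray or nested lists to an array, then check dtype and shape."""
    arr = np.asarray(value)
    # Hypothetical policy: only signed/unsigned integer dtypes count as "Int".
    if arr.dtype.kind not in "iu":
        raise ValueError(f"expected integer data, got dtype {arr.dtype}")
    # Hard-coded stand-in for the "W H 3" spec: any W and H, exactly 3 channels.
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError(f"expected shape (W, H, 3), got {arr.shape}")
    return arr


# An ndarray in Python; nested lists whenever the model is dumped (including to JSON).
IntImage = Annotated[
    np.ndarray,
    BeforeValidator(to_int_array),
    PlainSerializer(lambda arr: arr.tolist(), return_type=list),
]


class Sample(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    img: IntImage
    label: str


sample = Sample.model_validate({"img": np.ones((2, 2, 3), dtype=int), "label": "car"})
payload = sample.model_dump_json()  # img is written as nested JSON lists
restored = Sample.model_validate(json.loads(payload))  # the lists are re-validated
print(restored.img.shape)  # (2, 2, 3)

try:
    Sample.model_validate({"img": np.ones((2, 2, 3)), "label": "car"})  # float64 data
except ValidationError as err:
    print(err)
```

Because the ValueError raised inside to_int_array is converted into a pydantic ValidationError, shape and dtype problems surface the same way as any other field error.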



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

  • Source distribution: py-jaxtyping-0.1.4.tar.gz (2.8 kB)
  • Built distribution: py_jaxtyping-0.1.4-py3-none-any.whl (3.7 kB, Python 3)
