Penzai: A JAX research toolkit for building, editing, and visualizing neural networks.

Penzai

盆 ("pen", tray) 栽 ("zai", planting) - an ancient Chinese art of forming trees and landscapes in miniature, also called penjing and an ancestor of the Japanese art of bonsai.

Penzai is a JAX library for writing models as legible, functional pytree data structures, along with tools for visualizing, modifying, and analyzing them. Penzai focuses on making it easy to work with models after they have been trained, which makes it a good fit for research that involves reverse-engineering or ablating model components, inspecting and probing internal activations, performing model surgery, debugging architectures, and more. (But if you just want to build and train a model, you can do that too!)

With Penzai, your neural networks could look like this:

[Image: screenshot of the Gemma model in Penzai]

Penzai is structured as a collection of modular tools, designed together but each useable independently:

  • penzai.nn (pz.nn): A declarative, combinator-based neural network library that exposes the full structure of your model's forward pass in the model pytree. It is an alternative to other neural network libraries like Flax, Haiku, Keras, or Equinox. Because the structure is exposed, you can see everything your model does by pretty-printing it, and inject new runtime logic with jax.tree_util. Like Equinox, there's no magic: models are just callable pytrees under the hood.

  • penzai.treescope (pz.ts): A superpowered interactive Python pretty-printer, which works as a drop-in replacement for the ordinary IPython/Colab renderer. It's designed to help understand Penzai models and other deeply-nested JAX pytrees, with built-in support for visualizing arbitrary-dimensional NDArrays.

  • penzai.core.selectors (pz.select): A pytree Swiss Army knife, generalizing JAX's .at[...].set(...) syntax to arbitrary type-driven pytree traversals, and making it easy to do complex rewrites or on-the-fly patching of Penzai models and other data structures.

  • penzai.core.named_axes (pz.nx): A lightweight named axis system which lifts ordinary JAX functions to vectorize over named axes, and allows you to seamlessly switch between named and positional programming styles without having to learn a new array API.

  • penzai.data_effects (pz.de): An opt-in system for side arguments, random numbers, and state variables that is built on pytree traversal and puts you in control, without getting in the way of writing or using your model.
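
To make the "models are just callable pytrees" idea concrete, here is a minimal plain-Python sketch. It does not use Penzai or JAX at all: ToyScale, ToyElementwise, and ToySequential are hypothetical stand-ins invented for this illustration, and real Penzai layers operate on arrays rather than Python lists. The point is only that an immutable, callable data structure shows its whole forward pass when printed, and can be edited functionally without touching the original:

```python
# Illustrative only: plain-Python stand-ins, NOT Penzai or JAX classes.
from dataclasses import dataclass, replace
from typing import Callable, Tuple


@dataclass(frozen=True)
class ToyScale:
    """Hypothetical layer: multiplies every input element by `factor`."""
    factor: float

    def __call__(self, xs):
        return [self.factor * x for x in xs]


@dataclass(frozen=True)
class ToyElementwise:
    """Hypothetical layer: applies `fn` to each element independently."""
    fn: Callable[[float], float]

    def __call__(self, xs):
        return [self.fn(x) for x in xs]


@dataclass(frozen=True)
class ToySequential:
    """Hypothetical combinator: runs its sublayers in order."""
    sublayers: Tuple

    def __call__(self, xs):
        for layer in self.sublayers:
            xs = layer(xs)
        return xs


def relu(x):
    return max(x, 0.0)


# The model is just nested data; printing it shows every step it performs.
model = ToySequential((ToyScale(2.0), ToyElementwise(relu), ToyScale(0.5)))
print(model)
output = model([1.0, -3.0])

# A functional edit: build a new model with the first layer swapped out.
# The original `model` is left unchanged.
patched = replace(model, sublayers=(ToyScale(10.0),) + model.sublayers[1:])
patched_output = patched([1.0, -3.0])
```

In Penzai itself, edits like the last step would typically go through pz.select rather than hand-written tuple slicing; the sketch only shows why such edits are possible when the model is plain data.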

Documentation on Penzai can be found at https://penzai.readthedocs.io.

Warning: Penzai's API is currently unstable and may change in future releases.

In particular, the way Penzai handles parameter initialization, parameter sharing, and local mutable state in penzai.nn and penzai.data_effects is likely to be simplified in the future. Some internal details of the treescope pretty-printer intermediate representation may also change to make it easier to extend and configure.

Projects that use Penzai's neural network components or model implementations, or that define their own handlers for treescope, are encouraged to pin the 0.1.x release series (e.g. penzai>=0.1,<0.2) to avoid breaking changes.

Getting Started

If you haven't already installed JAX, you should do that first, since the installation process depends on your platform. You can find instructions in the JAX documentation. Afterward, you can install Penzai using

pip install penzai

and import it using

import penzai
from penzai import pz

(penzai.pz is an alias namespace, which makes it easier to reference common Penzai objects.)

When working in a Colab or IPython notebook, we recommend also configuring Penzai as the default pretty printer, and enabling some utilities for interactive use:

pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

# Optional: enables automatic array visualization
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

Here's how you could initialize and visualize a simple neural network:

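import jax  # needed for jax.random.key below
from penzai.example_models import simple_mlp

# Build the MLP structure, then fill in its parameters from a random key.
mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([8, 32, 32, 8]),
    jax.random.key(42),
)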

# Models and arrays are visualized automatically when you output them from a
# Colab/IPython notebook cell:
mlp

Here's how you could capture and extract the activations after the elementwise nonlinearities:

mlp_with_captured_activations = pz.de.CollectingSideOutputs.handling(
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(pz.de.TellIntermediate())
)

output, intermediates = mlp_with_captured_activations(
    pz.nx.ones({"features": 8})
)
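
If the pattern above feels opaque, the following plain-Python sketch may help build intuition. It is not Penzai's implementation and uses none of its API: ToyTell, insert_after_instances, and run_collecting are hypothetical names made up for this example, and it uses a shared mutable list where Penzai uses its data-effects handlers. It mimics the two steps: find every layer of a given type and insert an identity "tell" layer after it, then run the instrumented model and return the recorded values alongside the output:

```python
# Illustrative only: a toy analogue of "select by type, insert after, collect".
# None of these names come from Penzai.
from dataclasses import dataclass, replace
from typing import Callable, Tuple


@dataclass(frozen=True)
class ToyScale:
    factor: float

    def __call__(self, xs):
        return [self.factor * x for x in xs]


@dataclass(frozen=True)
class ToyElementwise:
    fn: Callable[[float], float]

    def __call__(self, xs):
        return [self.fn(x) for x in xs]


@dataclass(frozen=True)
class ToyTell:
    """Identity layer that appends a copy of its input to `log`."""
    log: list

    def __call__(self, xs):
        self.log.append(list(xs))
        return xs


@dataclass(frozen=True)
class ToySequential:
    sublayers: Tuple

    def __call__(self, xs):
        for layer in self.sublayers:
            xs = layer(xs)
        return xs


def insert_after_instances(seq, cls, make_layer):
    """Return a copy of `seq` with a new layer after each instance of `cls`."""
    new_layers = []
    for layer in seq.sublayers:
        new_layers.append(layer)
        if isinstance(layer, cls):
            new_layers.append(make_layer())
    return replace(seq, sublayers=tuple(new_layers))


def run_collecting(seq, cls, xs):
    """Run `seq` on `xs`, also returning the values seen after each `cls`."""
    log = []
    instrumented = insert_after_instances(seq, cls, lambda: ToyTell(log))
    return instrumented(xs), log


def relu(x):
    return max(x, 0.0)


toy_model = ToySequential(
    (ToyScale(2.0), ToyElementwise(relu), ToyScale(0.5), ToyElementwise(relu))
)
toy_output, toy_intermediates = run_collecting(
    toy_model, ToyElementwise, [1.0, -3.0]
)
```

As in the Penzai example, the original model is untouched: the instrumentation lives only in the temporary copy built inside run_collecting.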

To learn more about how to build and manipulate neural networks with Penzai, we recommend starting with the "How to Think in Penzai" tutorial, or one of the other tutorials in the Penzai documentation.


This is not an officially supported Google product.
