Skip to main content

Plot Tensor — visualize multi-dimensional arrays as faceted seaborn line plots

Project description

pt — Plot Tensor

pt is a Python utility for visualizing multi-dimensional arrays as faceted line plots. It maps each dimension of a tensor to a visual channel (colour, line width, linestyle, facet row, facet column) and renders the result using seaborn.

Supported array types: numpy.ndarray, JAX Array, Penzai NamedArray, xarray.DataArray

pip install pt  # or: uv add --editable /path/to/pt

Quick start

import numpy as np
import pt

# 1-D array — single line, time axis auto-detected
pt.line(np.random.randn(200).cumsum())

# 2-D array — first axis auto-mapped to colour, last to time
signals = np.random.randn(8, 200).cumsum(axis=-1)
pt.line(signals)

pt.line

pt.line(tensor, *, time=None, x=None,
        hue=None, color=None, color2d=None,
        style=None, size=None,
        row=None, col=None,
        dim_names=None, coords=None,
        palette=None, sizes=(0.5, 2.5), scale_linewidth_sqrt=False,
        dashes=None, height=3.0, aspect=1.5, col_wrap=None,
        alpha=0.8, legend=True, title=None, xlabel=None, ylabel=None,
        verbose=False, **kwargs) -> sns.FacetGrid

Each keyword argument maps one or more tensor axes to a visual channel. All are optional — pt.line applies smart defaults for under-specified cases.


Axis specification

Naming axes

For plain numpy / JAX arrays, axes are named dim_0, dim_1, … by default. Supply names via dim_names:

# List of strings
pt.line(x, dim_names=['batch', 'layer', 'time'])

# List of (name, coordinate_labels) tuples — names and labels together
pt.line(x, dim_names=[
    ('batch', None),                        # labels default to 0, 1, 2, …
    ('layer', ['L0', 'L1', 'L2', 'L3']),
    ('time',  np.linspace(0, 1, T)),
])

Penzai NamedArray and xarray.DataArray supply names (and xarray coordinates) automatically.

Coordinate labels

Override or supplement labels with the coords dict. Keys may be axis names or integer indices. None or () defaults to np.arange(n).

pt.line(x,
    dim_names=['batch', 'layer', 'time'],
    coords={
        'layer': ['L0', 'L1', 'L2', 'L3'],
        'time':  np.linspace(0.0, 1.0, T),
    })

Channel reference

time / x — x-axis (aliases)

The axis that becomes the x-axis of each line plot. One or neither may be specified.

  • Auto-detection: if any axis is named time, t, T, or x, it is automatically bound without needing time=.
  • Fallback: the last axis.
pt.line(x, time='t')      # explicit
pt.line(x, x='t')         # same thing
pt.line(x)                # auto-detected if a dim is named 'time'

hue / color — line colour (aliases)

Maps one or more axes to line colour. hue and color are identical; use whichever you prefer.

# Single axis → sequential palette
pt.line(x, hue='layer')

# Multiple axes → Cartesian-product, linearised onto a single palette
pt.line(x, color=['layer', 'head'])

Default palette: husl for ≤ 12 values (perceptually uniform categorical), viridis for > 12. Override with palette=:

pt.line(x, hue='layer', palette='tab10')
pt.line(x, hue='layer', palette=['#e41a1c', '#377eb8', '#4daf4a'])

color2d — 2-D colour palette

Maps exactly two axes to a 2-D HLS colour grid: the first axis varies hue across the colour wheel (0.05 → 0.85), the second varies lightness (0.35 → 0.65). This keeps both axes visually distinguishable simultaneously.

color2d is mutually exclusive with hue / color.

# head axis → hue direction, layer axis → lightness direction
pt.line(x, dim_names=['run', 'head', 'layer', 't'],
        color2d=['head', 'layer'], col='run')

A swatch-grid legend is placed on the right margin of the figure.

style — linestyle

Maps one axis to linestyle, cycling: solid → dashed → dotted → dash-dot → …

pt.line(x, hue='layer', style='condition')

# Custom dash patterns (matplotlib dash specs)
pt.line(x, style='condition',
        dashes=[(None,None), (4, 2), (1, 1)])

size — line width

Maps one or more axes to linewidth, linearly interpolated across sizes=(min, max).

pt.line(x, hue='layer', size='run', sizes=(0.5, 3.0))

# Area-proportional scaling (sqrt mode)
pt.line(x, size='run', sizes=(0.5, 3.0), scale_linewidth_sqrt=True)

row / col — facet axes

pt.line(x, hue='layer', row='batch', col='condition')

# Single faceting dimension with wrapping
pt.line(x, hue='layer', col='batch', col_wrap=4)

Unassigned axes

Any axis not mapped to a channel is mean-reduced with a UserWarning:

# 'batch' is unassigned → averaged over, warning emitted
pt.line(x, dim_names=['batch', 'layer', 'time'], hue='layer')
# UserWarning: Axes ['batch'] are not assigned to any channel and will be mean-reduced.

Pass verbose=True to print a table of how every axis is mapped before plotting:

pt.line(x, dim_names=['batch', 'layer', 'time'],
        hue='layer', row='batch', verbose=True)
dim           shape  role            coords
------------  -----  --------------  ------------------------
batch             4  row             [0, 1, 2, 3]
layer             6  hue             [0 .. 5]  (6)
time            100  x-axis          [0.00 .. 0.99]  (100)

Named array types

xarray DataArray

Dimension names and coordinate values are extracted automatically:

import xarray as xr

da = xr.DataArray(
    data,
    dims=['batch', 'layer', 'time'],
    coords={'layer': ['L0','L1','L2'], 'time': t_values},
)
pt.line(da, hue='layer', row='batch')

Penzai NamedArray

from penzai.core import named_axes as na

arr = na.NamedArray.wrap(data, ('batch', 'layer', 'time'))
pt.line(arr, hue='layer', row='batch')

JAX arrays

Converted to numpy automatically. Pass dim_names / coords to annotate axes.

import jax.numpy as jnp
pt.line(jnp.array(data), dim_names=['layer', 'time'], hue='layer')

Figure and aesthetic options

Parameter Default Description
height 3.0 Height of each facet in inches
aspect 1.5 Width-to-height ratio per facet
col_wrap None Wrap columns (only when row is not used)
alpha 0.8 Line opacity
legend True Show colour / size / style legends
title None Figure suptitle
xlabel None x-axis label (defaults to axis name)
ylabel None y-axis label (defaults to "value")
**kwargs Forwarded to ax.plot() (e.g. marker='o', linestyle='--')

Return value

pt.line returns a seaborn.FacetGrid, giving full access to the underlying figure and axes:

g = pt.line(x, hue='layer', row='batch')

g.set(xlim=(0, 100), ylim=(-5, 5))
g.set_titles(row_template='batch {row_name}')
g.figure.savefig('output.png', dpi=150, bbox_inches='tight')

Examples

Research workflow: compare activations across layers and conditions

# activations: shape (n_layers=12, n_conditions=4, n_tokens=64)
activations = model.get_activations(inputs)  # numpy array

g = pt.line(
    activations,
    dim_names=['layer', 'condition', 'token'],
    coords={
        'layer':     [f'L{i}' for i in range(12)],
        'condition': ['base', 'prefix', 'fewshot', 'finetune'],
    },
    time='token',
    hue='condition',
    row='layer',
    col_wrap=4,
    height=2.0,
    aspect=2.0,
    title='Layer activations by condition',
)

2-D colour map: heads × layers

# attention: shape (n_heads=8, n_layers=6, seq_len=128)
g = pt.line(
    attention,
    dim_names=['head', 'layer', 'position'],
    color2d=['head', 'layer'],
    time='position',
    alpha=0.6,
    sizes=(0.5, 1.5),
)

xarray with automatic coordinates

import xarray as xr

da = xr.DataArray(
    training_curves,                        # shape (runs, steps)
    dims=['run', 'step'],
    coords={
        'run':  [f'seed={s}' for s in seeds],
        'step': np.arange(n_steps) * log_interval,
    },
)

g = pt.line(da, hue='run', xlabel='Training step', ylabel='Loss')

Combining channels

# signals: (subject=10, condition=3, electrode=64, time=500)
g = pt.line(
    signals,
    dim_names=['subject', 'condition', 'electrode', 'time'],
    time='time',
    hue='condition',
    style='condition',   # redundant encoding: colour + linestyle
    size='electrode',    # thicker lines for higher electrode indices
    row='subject',
    sizes=(0.3, 2.0),
    alpha=0.5,
    verbose=True,
)

Installation

# From PyPI (when published)
pip install pt

# Editable install from local clone
uv add --editable /path/to/pt

# Optional extras
pip install pt[jax]     # JAX support
pip install pt[penzai]  # Penzai NamedArray support
pip install pt[xarray]  # xarray DataArray support

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

plot_tensor-0.1.0.tar.gz (108.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

plot_tensor-0.1.0-py3-none-any.whl (13.5 kB view details)

Uploaded Python 3

File details

Details for the file plot_tensor-0.1.0.tar.gz.

File metadata

  • Download URL: plot_tensor-0.1.0.tar.gz
  • Upload date:
  • Size: 108.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for plot_tensor-0.1.0.tar.gz
Algorithm Hash digest
SHA256 2afab8111d829d333907b93547ec1ce5d0ecd02c5a7ddcab232921ee1c94e084
MD5 f39df4a7578d64dd764d2f54e6d89e6a
BLAKE2b-256 9ca8883ba6e5783cbfc6881fb64981df56af12a3ab5a76f924a2cf02264cd5bb

See more details on using hashes here.

File details

Details for the file plot_tensor-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: plot_tensor-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 13.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for plot_tensor-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 3ab91cc91a466f792109eea5ad561c856253778e4565af6cfa7c7979ddec65f8
MD5 23b04c2e93e64d1fcae1d67c3033635f
BLAKE2b-256 7a4b1a72f7cee6810732654c15d354912e66a63e14d21770e234753b9803d488

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page