Plot Tensor — visualize multi-dimensional arrays as faceted seaborn line plots
pt — Plot Tensor
pt is a Python utility for visualizing multi-dimensional arrays as faceted line plots. It maps each dimension of a tensor to a visual channel (colour, line width, linestyle, facet row, facet column) and renders the result using seaborn.
Supported array types: numpy.ndarray, JAX Array, Penzai NamedArray, xarray.DataArray
pip install plot-tensor # imported as pt; or: uv add --editable /path/to/pt
Quick start
import numpy as np
import pt
# 1-D array: a single line; its only axis becomes the x-axis
pt.line(np.random.randn(200).cumsum())
# 2-D array: first axis auto-mapped to colour, last axis to the x-axis
signals = np.random.randn(8, 200).cumsum(axis=-1)
pt.line(signals)
pt.line
pt.line(tensor, *, time=None, x=None,
hue=None, color=None, color2d=None,
style=None, size=None,
row=None, col=None,
dim_names=None, coords=None,
palette=None, sizes=(0.5, 2.5), scale_linewidth_sqrt=False,
dashes=None, height=3.0, aspect=1.5, col_wrap=None,
alpha=0.8, legend=True, title=None, xlabel=None, ylabel=None,
verbose=False, **kwargs) -> sns.FacetGrid
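Each channel keyword (time / x, hue / color, color2d, style, size, row, col) maps one or more tensor axes to a visual channel. All are optional: pt.line chooses the x-axis automatically and applies defaults to under-specified calls, as described in the sections below.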
Axis specification
Naming axes
For plain numpy / JAX arrays, axes are named dim_0, dim_1, … by default. Supply names via dim_names:
# List of strings
pt.line(x, dim_names=['batch', 'layer', 'time'])
# List of (name, coordinate_labels) tuples — names and labels together
pt.line(x, dim_names=[
('batch', None), # labels default to 0, 1, 2, …
('layer', ['L0', 'L1', 'L2', 'L3']),
('time', np.linspace(0, 1, T)),
])
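The two accepted forms of dim_names can be pictured as a small normalisation step. The sketch below is hypothetical: normalize_dim_names is not part of pt's documented API, and it only assumes the behaviour stated above (names default to dim_0, dim_1, …; a None label list means 0, 1, 2, …).

```python
import numpy as np


def normalize_dim_names(shape, dim_names=None):
    """Hypothetical helper: turn dim_names into (names, coords).

    Accepts None, a list of strings, or a list of (name, labels) tuples.
    """
    if dim_names is None:
        # Default names for plain arrays: dim_0, dim_1, ...
        return [f"dim_{i}" for i in range(len(shape))], {}
    if len(dim_names) != len(shape):
        raise ValueError(f"expected {len(shape)} names, got {len(dim_names)}")
    names, coords = [], {}
    for n, entry in zip(shape, dim_names):
        if isinstance(entry, tuple):
            name, labels = entry
            # Missing labels default to 0, 1, 2, ...
            labels = np.arange(n) if labels is None else np.asarray(labels)
            if len(labels) != n:
                raise ValueError(f"axis {name!r}: {len(labels)} labels for length {n}")
            coords[name] = labels
        else:
            # Plain string: a name without explicit labels.
            name = entry
        names.append(name)
    return names, coords
```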
Penzai NamedArray and xarray.DataArray supply names (and xarray coordinates) automatically.
Coordinate labels
Override or supplement labels with the coords dict. Keys may be axis names or integer axis indices. A value of None or () falls back to np.arange(n), where n is the length of that axis.
pt.line(x,
dim_names=['batch', 'layer', 'time'],
coords={
'layer': ['L0', 'L1', 'L2', 'L3'],
'time': np.linspace(0.0, 1.0, T),
})
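Resolving the coords dict against the axis names could look roughly like this. resolve_coords is a hypothetical helper written for illustration; it assumes a name key takes precedence over an integer key for the same axis, which the README does not specify.

```python
import numpy as np


def resolve_coords(names, shape, coords=None):
    """Hypothetical helper: one label array per axis.

    Keys of `coords` may be axis names or integer axis indices;
    a value of None or () falls back to np.arange(n).
    """
    coords = coords or {}
    resolved = {}
    for i, (name, n) in enumerate(zip(names, shape)):
        # Look up by name first, then by integer position (assumed order).
        labels = coords.get(name, coords.get(i))
        if labels is None or (isinstance(labels, tuple) and len(labels) == 0):
            labels = np.arange(n)
        labels = np.asarray(labels)
        if labels.shape != (n,):
            raise ValueError(f"axis {name!r}: expected {n} labels, got shape {labels.shape}")
        resolved[name] = labels
    return resolved
```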
Channel reference
time / x — x-axis (aliases)
The axis that becomes the x-axis of each line plot. Specify at most one of the two aliases.
- Auto-detection: if any axis is named time, t, T, or x, it is bound to the x-axis automatically, without needing time=.
- Fallback: otherwise, the last axis is used.
pt.line(x, time='t') # explicit
pt.line(x, x='t') # same thing
pt.line(x) # auto-detected if a dim is named 'time'
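The binding rules above can be sketched as a small resolver. resolve_x_axis is hypothetical; where the README is silent, it assumes that passing both aliases is an error and that, if several axes carry a reserved name, the first one wins.

```python
# Names that trigger auto-detection, as listed above.
AUTO_TIME_NAMES = ("time", "t", "T", "x")


def resolve_x_axis(names, time=None, x=None):
    """Hypothetical helper: index of the axis used as the x-axis."""
    if time is not None and x is not None:
        raise ValueError("time= and x= are aliases; pass at most one")
    explicit = time if time is not None else x
    if explicit is not None:
        # Explicit binding by axis name.
        return names.index(explicit)
    # Auto-detection: first axis carrying one of the reserved names.
    for i, name in enumerate(names):
        if name in AUTO_TIME_NAMES:
            return i
    # Fallback: the last axis.
    return len(names) - 1
```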
hue / color — line colour (aliases)
Maps one or more axes to line colour. hue and color are identical; use whichever you prefer.
# Single axis → one colour per coordinate value
pt.line(x, hue='layer')
# Multiple axes → Cartesian product of their values, linearised onto a single palette
pt.line(x, color=['layer', 'head'])
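A sketch of the Cartesian-product linearisation, assuming row-major order (the first listed axis varies slowest), which the README does not state explicitly. linearize_levels is a hypothetical helper: each combination of indices gets one position on the palette.

```python
import itertools


def linearize_levels(level_counts):
    """Hypothetical helper: map each combination of axis indices to one
    position on a single palette, in row-major (C) order."""
    combos = itertools.product(*(range(n) for n in level_counts))
    return {combo: pos for pos, combo in enumerate(combos)}
```

For two axes of lengths 2 and 3 this yields 6 palette positions; np.ravel_multi_index((i, j), (2, 3)) computes the same position for a single pair.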
Default palette: husl for ≤ 12 values (perceptually uniform categorical), viridis for > 12. Override with palette=:
pt.line(x, hue='layer', palette='tab10')
pt.line(x, hue='layer', palette=['#e41a1c', '#377eb8', '#4daf4a'])
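The default rule can be written down directly. default_palette_name encodes only the thresholds stated above; default_palette is an assumption about how such a palette might be built with seaborn.color_palette, not a description of pt's internals.

```python
def default_palette_name(n_values):
    """Default rule described above: husl up to 12 values, viridis beyond."""
    return "husl" if n_values <= 12 else "viridis"


def default_palette(n_values):
    """Assumed construction: one colour per value via seaborn."""
    # Imported lazily so the rule above can be used without seaborn installed.
    import seaborn as sns

    return sns.color_palette(default_palette_name(n_values), n_values)
```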
color2d — 2-D colour palette
Maps exactly two axes to a 2-D HLS colour grid: the first axis varies hue across the colour wheel (0.05 → 0.85), the second varies lightness (0.35 → 0.65). This keeps both axes visually distinguishable simultaneously.
color2d is mutually exclusive with hue / color.
# head axis → hue direction, layer axis → lightness direction
pt.line(x, dim_names=['run', 'head', 'layer', 't'],
color2d=['head', 'layer'], col='run')
A swatch-grid legend is placed on the right margin of the figure.
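A minimal sketch of such a 2-D HLS grid using the standard-library colorsys module. color2d_grid is hypothetical, and two details are assumptions because the README does not give them: levels are evenly spaced within each range, and saturation is fixed at 0.65.

```python
import colorsys


def color2d_grid(n_hue, n_light, saturation=0.65):
    """Hypothetical sketch: RGB colour for every (i, j) pair of two axes.

    Axis 1 sweeps hue 0.05 -> 0.85, axis 2 sweeps lightness 0.35 -> 0.65.
    The saturation value is an assumption; the README does not state it.
    """
    def sweep(lo, hi, n):
        # Evenly spaced values hitting both ends exactly;
        # a single level sits at the start of the range (assumption).
        if n == 1:
            return [lo]
        return [(1 - k / (n - 1)) * lo + (k / (n - 1)) * hi for k in range(n)]

    hues = sweep(0.05, 0.85, n_hue)
    lights = sweep(0.35, 0.65, n_light)
    # colorsys takes arguments in (hue, lightness, saturation) order.
    return [[colorsys.hls_to_rgb(h, light, saturation) for light in lights] for h in hues]
```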
style — linestyle
Maps one axis to linestyle, cycling: solid → dashed → dotted → dash-dot → …
pt.line(x, hue='layer', style='condition')
# Custom dash patterns (matplotlib dash specs)
pt.line(x, style='condition',
dashes=[(None,None), (4, 2), (1, 1)])
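The cycling behaviour can be sketched with itertools. linestyles is a hypothetical helper; it assumes the default cycle uses matplotlib's shorthand strings and that a custom dashes list is cycled the same way when there are more levels than entries.

```python
from itertools import cycle, islice

# Matplotlib shorthand for solid, dashed, dotted and dash-dot lines.
DEFAULT_STYLES = ["-", "--", ":", "-."]


def linestyles(n_levels, dashes=None):
    """Hypothetical sketch: one linestyle per level, cycling when exhausted.

    With `dashes`, each entry is an (on, off) dash spec in points, where
    (None, None) means solid.
    """
    pool = DEFAULT_STYLES if dashes is None else list(dashes)
    return list(islice(cycle(pool), n_levels))
```

A dash spec such as (4, 2) can be applied to an existing matplotlib line with line.set_dashes((4, 2)); (None, None) resets it to solid.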
size — line width
Maps one or more axes to linewidth, linearly interpolated across sizes=(min, max).
pt.line(x, hue='layer', size='run', sizes=(0.5, 3.0))
# Area-proportional scaling (sqrt mode)
pt.line(x, size='run', sizes=(0.5, 3.0), scale_linewidth_sqrt=True)
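For the linear mode, the interpolation amounts to evenly spaced widths between the two ends of sizes. linewidths is a hypothetical sketch; the single-level case and the direction (first level thinnest) are assumptions, and the sqrt mode is omitted because its exact formula is not documented here.

```python
def linewidths(n_levels, sizes=(0.5, 2.5)):
    """Hypothetical sketch of the linear mode: evenly spaced widths from
    sizes[0] (first level) to sizes[1] (last level)."""
    lo, hi = sizes
    if n_levels == 1:
        return [lo]  # assumption: a single level gets the minimum width
    fracs = [k / (n_levels - 1) for k in range(n_levels)]
    # Convex combination, so both ends of `sizes` are hit exactly.
    return [(1 - f) * lo + f * hi for f in fracs]
```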
row / col — facet axes
pt.line(x, hue='layer', row='batch', col='condition')
# Single faceting dimension with wrapping
pt.line(x, hue='layer', col='batch', col_wrap=4)
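Wrapping determines the grid shape. The sketch below mirrors, to our understanding, the arithmetic seaborn's FacetGrid applies when col_wrap is set (columns fixed at col_wrap, rows rounded up), including its refusal to combine row with col_wrap. facet_layout itself is hypothetical.

```python
import math


def facet_layout(n_col_levels=1, n_row_levels=None, col_wrap=None):
    """Hypothetical sketch: (n_rows, n_cols) of the facet grid.

    n_row_levels=None means no axis is mapped to row=.
    """
    if col_wrap is None:
        return (n_row_levels or 1), n_col_levels
    if n_row_levels is not None:
        # seaborn's FacetGrid refuses row= together with col_wrap=.
        raise ValueError("col_wrap cannot be combined with row")
    # Facets fill rows left to right; the last row may be partial.
    return math.ceil(n_col_levels / col_wrap), col_wrap
```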
Unassigned axes
Any axis not mapped to a channel is mean-reduced with a UserWarning:
# 'batch' is unassigned → averaged over, warning emitted
pt.line(x, dim_names=['batch', 'layer', 'time'], hue='layer')
# UserWarning: Axes ['batch'] are not assigned to any channel and will be mean-reduced.
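The reduction step can be sketched with NumPy. reduce_unassigned is a hypothetical helper; it assumes the x-axis counts as assigned and reproduces the warning text shown above.

```python
import warnings

import numpy as np


def reduce_unassigned(data, names, assigned):
    """Hypothetical sketch: mean-reduce every axis not bound to a channel.

    Returns the reduced array and the names of the surviving axes.
    """
    unassigned = [n for n in names if n not in assigned]
    if not unassigned:
        return data, list(names)
    warnings.warn(
        f"Axes {unassigned} are not assigned to any channel and will be mean-reduced.",
        UserWarning,
        stacklevel=2,
    )
    axes = tuple(names.index(n) for n in unassigned)
    kept = [n for n in names if n in assigned]
    return np.mean(data, axis=axes), kept
```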
Pass verbose=True to print a table of how every axis is mapped before plotting:
pt.line(x, dim_names=['batch', 'layer', 'time'],
hue='layer', row='batch', verbose=True)
dim shape role coords
------------ ----- -------------- ------------------------
batch 4 row [0, 1, 2, 3]
layer 6 hue [0 .. 5] (6)
time 100 x-axis [0.00 .. 0.99] (100)
Named array types
xarray DataArray
Dimension names and coordinate values are extracted automatically:
import xarray as xr
da = xr.DataArray(
data,
dims=['batch', 'layer', 'time'],
coords={'layer': ['L0','L1','L2'], 'time': t_values},
)
pt.line(da, hue='layer', row='batch')
Penzai NamedArray
from penzai.core import named_axes as na
arr = na.wrap(data, 'batch', 'layer', 'time')
pt.line(arr, hue='layer', row='batch')
JAX arrays
Converted to numpy automatically. Pass dim_names / coords to annotate axes.
import jax.numpy as jnp
pt.line(jnp.array(data), dim_names=['layer', 'time'], hue='layer')
Figure and aesthetic options
| Parameter | Default | Description |
|---|---|---|
| height | 3.0 | Height of each facet in inches |
| aspect | 1.5 | Width-to-height ratio per facet |
| col_wrap | None | Wrap columns (only when row is not used) |
| alpha | 0.8 | Line opacity |
| legend | True | Show colour / size / style legends |
| title | None | Figure suptitle |
| xlabel | None | x-axis label (defaults to axis name) |
| ylabel | None | y-axis label (defaults to "value") |
| **kwargs | | Forwarded to ax.plot() (e.g. marker='o', linestyle='--') |
Return value
pt.line returns a seaborn.FacetGrid, giving full access to the underlying figure and axes:
g = pt.line(x, hue='layer', row='batch')
g.set(xlim=(0, 100), ylim=(-5, 5))
g.set_titles(row_template='batch {row_name}')
g.figure.savefig('output.png', dpi=150, bbox_inches='tight')
Examples
Research workflow: compare activations across layers and conditions
# activations: shape (n_layers=12, n_conditions=4, n_tokens=64)
activations = model.get_activations(inputs) # numpy array
g = pt.line(
activations,
dim_names=['layer', 'condition', 'token'],
coords={
'layer': [f'L{i}' for i in range(12)],
'condition': ['base', 'prefix', 'fewshot', 'finetune'],
},
time='token',
hue='condition',
col='layer', # col_wrap cannot be combined with row
col_wrap=4,
height=2.0,
aspect=2.0,
title='Layer activations by condition',
)
2-D colour map: heads × layers
# attention: shape (n_heads=8, n_layers=6, seq_len=128)
g = pt.line(
attention,
dim_names=['head', 'layer', 'position'],
color2d=['head', 'layer'],
time='position',
alpha=0.6,
sizes=(0.5, 1.5),
)
xarray with automatic coordinates
import xarray as xr
da = xr.DataArray(
training_curves, # shape (runs, steps)
dims=['run', 'step'],
coords={
'run': [f'seed={s}' for s in seeds],
'step': np.arange(n_steps) * log_interval,
},
)
g = pt.line(da, hue='run', xlabel='Training step', ylabel='Loss')
Combining channels
# signals: (subject=10, condition=3, electrode=64, time=500)
g = pt.line(
signals,
dim_names=['subject', 'condition', 'electrode', 'time'],
time='time',
hue='condition',
style='condition', # redundant encoding: colour + linestyle
size='electrode', # thicker lines for higher electrode indices
row='subject',
sizes=(0.3, 2.0),
alpha=0.5,
verbose=True,
)
Installation
# From PyPI (distribution name plot-tensor; import name pt)
pip install plot-tensor
# Editable install from local clone
uv add --editable /path/to/pt
# Optional extras
pip install "plot-tensor[jax]" # JAX support
pip install "plot-tensor[penzai]" # Penzai NamedArray support
pip install "plot-tensor[xarray]" # xarray DataArray support
Download files
File details
Details for the file plot_tensor-0.1.0.tar.gz.
File metadata
- Download URL: plot_tensor-0.1.0.tar.gz
- Size: 108.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2afab8111d829d333907b93547ec1ce5d0ecd02c5a7ddcab232921ee1c94e084 |
| MD5 | f39df4a7578d64dd764d2f54e6d89e6a |
| BLAKE2b-256 | 9ca8883ba6e5783cbfc6881fb64981df56af12a3ab5a76f924a2cf02264cd5bb |
File details
Details for the file plot_tensor-0.1.0-py3-none-any.whl.
File metadata
- Download URL: plot_tensor-0.1.0-py3-none-any.whl
- Size: 13.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3ab91cc91a466f792109eea5ad561c856253778e4565af6cfa7c7979ddec65f8 |
| MD5 | 23b04c2e93e64d1fcae1d67c3033635f |
| BLAKE2b-256 | 7a4b1a72f7cee6810732654c15d354912e66a63e14d21770e234753b9803d488 |