
Unified 3D transform utilities supporting both NumPy and PyTorch backends


uni-transform

A Python library for 3D rigid body transformations with NumPy and PyTorch backends.

Python 3.8+ License: MIT

Key Features:

  • Dual API - Transform for poses (rotation + translation), Rotation for pure rotations
  • Batch Operations - All functions support arbitrary batch dimensions (..., N)
  • PyTorch Gradients - Fully differentiable for deep learning
  • Dual Backend - Seamless NumPy ↔ PyTorch switching
  • Extra State - Store additional per-pose state (e.g. gripper width, joint states) alongside poses
  • Unit Support - Explicit translation units (meters/millimeters)
  • Visualization - 3D visualization with Rerun (optional)

Installation

uv add uni-transform

From source:

git clone https://github.com/junhaotu/uni-transform.git
cd uni-transform
uv pip install -e .

Quick Start

Transform Class (Rotation + Translation)

import numpy as np
from uni_transform import Transform

# Create from euler angles [x, y, z, roll, pitch, yaw]
tf = Transform.from_rep(np.array([1.0, 2.0, 3.0, 0.1, 0.2, 0.3]), from_rep="euler")

# Convert representations
quat_rep = tf.to_rep("quat")      # [x, y, z, qx, qy, qz, qw]
matrix = tf.as_matrix()            # 4x4 homogeneous matrix

# Compose & transform
tf_composed = tf @ tf.inverse()
points = tf.transform_point(np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]))  # (N, 3) points
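A rigid transform is mathematically equivalent to a 4x4 homogeneous matrix. The minimal NumPy sketch below illustrates the math that as_matrix(), inverse(), @ and transform_point() correspond to. The helpers make_homogeneous, invert_homogeneous and apply_to_points are hypothetical and are not the library's implementation.

```python
import numpy as np


def make_homogeneous(rotation: np.ndarray, translation: np.ndarray) -> np.ndarray:
    """Pack a 3x3 rotation and a 3-vector translation into a 4x4 matrix (hypothetical helper)."""
    T = np.eye(4)
    T[:3, :3] = rotation
    T[:3, 3] = translation
    return T


def invert_homogeneous(T: np.ndarray) -> np.ndarray:
    """Closed-form rigid inverse: inv([R | t]) = [R^T | -R^T t], cheaper than np.linalg.inv."""
    R, t = T[:3, :3], T[:3, 3]
    return make_homogeneous(R.T, -R.T @ t)


def apply_to_points(T: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Apply T to an (N, 3) array of points: p' = R p + t."""
    return points @ T[:3, :3].T + T[:3, 3]


# 90-degree rotation about z plus a translation
c, s = np.cos(np.pi / 2), np.sin(np.pi / 2)
Rz = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
T = make_homogeneous(Rz, np.array([1.0, 2.0, 3.0]))

identity = T @ invert_homogeneous(T)  # analogous to tf @ tf.inverse()
points = apply_to_points(T, np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]))
# points is approximately [[1, 3, 3], [0, 2, 3]]
```

The closed-form inverse exploits R^-1 = R^T for rotation matrices, which is both faster and numerically better behaved than a general matrix inverse.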

Rotation Class (Pure Rotation)

from uni_transform import Rotation

# Create from various representations
rot = Rotation.from_euler(np.array([0.1, 0.2, 0.3]), seq="ZYX")
rot = Rotation.from_quat(np.array([0, 0, 0.707, 0.707]))
rot = Rotation.from_rotvec(np.array([0, 0, np.pi/2]))
rot = Rotation.from_rotation_6d(np.array([1, 0, 0, 0, 1, 0]))

# Or use generic from_rep
rot = Rotation.from_rep(np.array([0.1, 0.2, 0.3]), from_rep="euler", seq="ZYX")

# Convert to any representation
quat = rot.to_rep("quat")
euler = rot.as_euler(seq="ZYX", degrees=True)
matrix = rot.as_matrix()

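# Compose rotations
rot1 = Rotation.from_rotvec(np.array([0.0, 0.0, np.pi/4]))
rot2 = Rotation.from_rotvec(np.array([0.0, np.pi/4, 0.0]))
combined = rot1 @ rot2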

# Apply to vectors
rotated = rot.apply(np.array([1.0, 0.0, 0.0]))

# Interpolation
rot_mid = rot1.slerp(rot2, t=0.5)

# Relative rotation
rot_rel = rot2.relative_to(rot1)  # rot1 @ rot_rel = rot2
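The identity in the comment (rot1 @ rot_rel = rot2) follows from rot_rel = inverse(rot1) @ rot2, and the inverse of a rotation matrix is its transpose. The NumPy check below uses a hypothetical axis_angle_matrix helper (Rodrigues' formula) rather than the library's constructors.

```python
import numpy as np


def axis_angle_matrix(axis, angle: float) -> np.ndarray:
    """Rodrigues' formula: rotation of `angle` radians about `axis` (hypothetical helper)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    # Skew-symmetric cross-product matrix of the unit axis
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)


R1 = axis_angle_matrix([0, 0, 1], np.pi / 4)
R2 = axis_angle_matrix([0, 1, 0], np.pi / 3)

# Relative rotation: inverse of a rotation matrix is its transpose
R_rel = R1.T @ R2

# Composition order matters: R1 @ R2 and R2 @ R1 differ in general
commutes = np.allclose(R1 @ R2, R2 @ R1)  # False for these two rotations
```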

With Extra State (Gripper Width, etc.)

# Robot pose with gripper: [x, y, z, qx, qy, qz, qw, gripper_width]
pose = np.array([0.5, 0.2, 0.1, 0, 0, 0, 1, 0.04])
tf = Transform.from_rep(pose, from_rep="quat", extra_dims=1)

tf.translation  # [0.5, 0.2, 0.1]
tf.extra        # [0.04] - gripper width preserved

# Extra is included in conversions
euler_pose = tf.to_rep("euler")  # [x, y, z, roll, pitch, yaw, gripper]

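# Extra is interpolated too
from uni_transform import interpolate_transform
tf_start = tf
tf_end = Transform.from_rep(np.array([0.6, 0.2, 0.1, 0, 0, 0, 1, 0.08]), from_rep="quat", extra_dims=1)
tf_mid = interpolate_transform(tf_start, tf_end, t=0.5)
tf_mid.extra  # Linearly interpolated gripper width (midway between 0.04 and 0.08)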
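Conceptually, a flat quaternion pose vector with extra_dims=1 is [translation | quaternion (xyzw) | extra] concatenated along the last axis. The sketch below splits such a vector and linearly interpolates the extra state with plain NumPy. split_quat_pose is a hypothetical helper that illustrates the layout; it is not part of uni-transform.

```python
import numpy as np


def split_quat_pose(pose: np.ndarray, extra_dims: int = 0):
    """Split [..., x, y, z, qx, qy, qz, qw, extra...] into its parts (hypothetical helper)."""
    translation = pose[..., :3]
    quat_xyzw = pose[..., 3:7]
    extra = pose[..., 7:7 + extra_dims] if extra_dims > 0 else None
    return translation, quat_xyzw, extra


start = np.array([0.5, 0.2, 0.1, 0, 0, 0, 1, 0.04])
end = np.array([0.6, 0.2, 0.1, 0, 0, 0, 1, 0.08])

_, _, extra_start = split_quat_pose(start, extra_dims=1)
_, _, extra_end = split_quat_pose(end, extra_dims=1)

# Linear interpolation of the extra state at t = 0.5
t = 0.5
extra_mid = (1 - t) * extra_start + t * extra_end  # approx [0.06]
```

The ellipsis slicing keeps the same code working for any batch shape (..., 8).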

Translation Units

# Explicit unit tracking prevents accidental mixing
tf_m = Transform.from_pos_quat([1.0, 2.0, 3.0], [0, 0, 0, 1], translation_unit="m")
tf_mm = tf_m.to_unit("mm")  # Translation becomes [1000, 2000, 3000]

# Unit mismatch raises error (catches bugs early)
tf_mm @ tf_m  # Raises UnitMismatchError

# Stack also requires matching units
Transform.stack([tf_m, tf_m.clone()])  # OK
Transform.stack([tf_m, tf_mm])         # Raises UnitMismatchError
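Unit tracking can be pictured as a check that runs before composing or stacking. The following is a hypothetical, stripped-down sketch: the Pose class, the SCALE_TO_METERS table and this UnitMismatchError definition are illustrative assumptions, not the library's code. It shows that only the translation is rescaled, and that composition refuses mixed units.

```python
from dataclasses import dataclass

import numpy as np

# Hypothetical table: how many meters one unit represents
SCALE_TO_METERS = {"m": 1.0, "mm": 1e-3}


class UnitMismatchError(ValueError):
    """Raised when two poses with different translation units are combined (illustrative)."""


@dataclass
class Pose:
    rotation: np.ndarray     # (3, 3)
    translation: np.ndarray  # (3,)
    unit: str = "m"

    def to_unit(self, unit: str) -> "Pose":
        # Only the translation is scaled; rotations are unit-free
        scale = SCALE_TO_METERS[self.unit] / SCALE_TO_METERS[unit]
        return Pose(self.rotation, self.translation * scale, unit)

    def __matmul__(self, other: "Pose") -> "Pose":
        if self.unit != other.unit:
            raise UnitMismatchError(f"{self.unit!r} vs {other.unit!r}")
        return Pose(
            self.rotation @ other.rotation,
            self.rotation @ other.translation + self.translation,
            self.unit,
        )


p_m = Pose(np.eye(3), np.array([1.0, 2.0, 3.0]), "m")
p_mm = p_m.to_unit("mm")  # translation becomes [1000, 2000, 3000]
```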

Rotation Conversions (Functional API)

from uni_transform import quaternion_to_matrix, matrix_to_euler, convert_rotation

# Direct conversions
quat = np.array([0, 0, 0.707, 0.707])  # xyzw format
matrix = quaternion_to_matrix(quat)
euler = matrix_to_euler(matrix, seq="ZYX")

# Generic conversion
euler = convert_rotation(quat, from_rep="quat", to_rep="euler", seq="ZYX")

# Static method on classes (no instance needed)
euler = Rotation.convert(quat, from_rep="quat", to_rep="euler", seq="ZYX")
pose_euler = np.array([1.0, 2.0, 3.0, 0.1, 0.2, 0.3])  # [x, y, z, roll, pitch, yaw]
pose_quat = Transform.convert(pose_euler, from_rep="euler", to_rep="quat")
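For reference, the standard conversion from a unit quaternion in xyzw order to a rotation matrix is sketched below in plain NumPy. quat_xyzw_to_matrix is a hypothetical stand-in that illustrates the formula and the xyzw layout; in practice, use quaternion_to_matrix from the package.

```python
import numpy as np


def quat_xyzw_to_matrix(q) -> np.ndarray:
    """Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix (hypothetical helper)."""
    q = np.asarray(q, dtype=float)
    x, y, z, w = q / np.linalg.norm(q)  # normalize so inputs like 0.707 still give a valid rotation
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ])


# 90 degrees about z: (0, 0, sin(pi/4), cos(pi/4)) in xyzw order
R = quat_xyzw_to_matrix([0.0, 0.0, 0.707, 0.707])
rotated_x = R @ np.array([1.0, 0.0, 0.0])  # approx [0, 1, 0]
```

Passing a wxyz quaternion to this function would silently produce the wrong rotation, which is why the xyzw convention matters.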

PyTorch Gradients

import torch
from uni_transform import Transform, Rotation, geodesic_distance

# Transform with gradients: 9 = 3 translation + 6 rotation_6d values
pred_params = torch.randn(100, 9, requires_grad=True)
pred = Transform.from_rep(pred_params, from_rep="rotation_6d")
target = Transform.from_rep(torch.randn(100, 9), from_rep="rotation_6d")

loss = geodesic_distance(pred.rotation, target.rotation).mean()
loss.backward()  # Gradients flow back to pred_params.grad

# Rotation with gradients
angles = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
rot = Rotation.from_euler(angles)
rotated = rot.apply(torch.tensor([1.0, 0.0, 0.0]))
rotated.sum().backward()  # angles.grad now holds the gradient
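A common definition of the geodesic (angular) distance between rotation matrices is theta = arccos((trace(R1^T R2) - 1) / 2). The NumPy sketch below assumes this is what geodesic_distance measures; the README does not state the formula. Clipping the cosine into [-1, 1] guards against round-off. In a PyTorch loss you would typically clamp slightly inside that range, because the gradient of arccos is unbounded at +/-1.

```python
import numpy as np


def rot_z(angle: float) -> np.ndarray:
    """Rotation of `angle` radians about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def geodesic_angle(R1: np.ndarray, R2: np.ndarray) -> np.ndarray:
    """Angle in radians of the relative rotation R1^T R2; works on (..., 3, 3) batches."""
    R_rel = np.swapaxes(R1, -1, -2) @ R2
    trace = np.trace(R_rel, axis1=-2, axis2=-1)
    cos_angle = np.clip((trace - 1.0) / 2.0, -1.0, 1.0)
    return np.arccos(cos_angle)


angle = geodesic_angle(np.eye(3), rot_z(0.3))  # approx 0.3

# Batched: a (5, 3, 3) stack compared against the identity
batch = np.stack([rot_z(a) for a in np.linspace(0.0, 1.0, 5)])
angles = geodesic_angle(np.broadcast_to(np.eye(3), batch.shape), batch)
```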

Batch Dimensions

# All operations support arbitrary batch dimensions
batch_tf = Transform.from_rep(np.random.randn(10, 50, 7), from_rep="quat")  # (10, 50) batch
batch_tf.rotation.shape   # (10, 50, 3, 3)
batch_tf.translation.shape  # (10, 50, 3)

# Same for Rotation
batch_rot = Rotation.from_rep(np.random.randn(10, 50, 4), from_rep="quat")
batch_rot.batch_shape  # (10, 50)

# Compose batched transforms
result = batch_tf @ batch_tf.inverse()  # Element-wise over the (10, 50) batch; broadcasting also supported
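Batched composition follows the matmul broadcasting rules that NumPy and PyTorch share for leading dimensions. The sketch below stores transforms as (..., 4, 4) homogeneous matrices, an illustrative choice (Transform itself keeps rotation and translation separately), and composes a (10, 50) batch with a single transform and with a (50,) batch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative batch with batch shape (10, 50): identity rotations, random translations
batch = np.broadcast_to(np.eye(4), (10, 50, 4, 4)).copy()
batch[..., :3, 3] = rng.normal(size=(10, 50, 3))

single = np.eye(4)
single[:3, 3] = [1.0, 0.0, 0.0]  # one shared offset, shape (4, 4)

per_step = np.broadcast_to(np.eye(4), (50, 4, 4)).copy()  # batch shape (50,)

# matmul broadcasts leading dims: (4, 4) and (50, 4, 4) both broadcast against (10, 50, 4, 4)
out_single = single @ batch
out_per_step = batch @ per_step
```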

Interpolation

from uni_transform import (
    interpolate, interpolate_sequence,
    interpolate_transform, interpolate_transform_sequence,
    compute_spline, quaternion_slerp
)

# Vector interpolation
pos = interpolate(start, end, t=0.5)  # Linear
pos = interpolate(start, end, t=0.5, method="minimum_jerk", duration=2.0)  # Smooth

# Multi-point vector interpolation
positions = interpolate_sequence(waypoints, times, query_times, method="cubic_spline")

# Transform interpolation (rotation + translation)
tf_mid = interpolate_transform(tf_start, tf_end, t=0.5)
tf_mid = interpolate_transform(tf_start, tf_end, t=0.5,
    rotation_method="nlerp",           # "slerp" or "nlerp"
    translation_method="minimum_jerk", # "linear" or "minimum_jerk"
    duration=2.0
)

# Multi-point transform interpolation
keyframes = Transform.stack([tf0, tf1, tf2, tf3])
times = np.array([0.0, 1.0, 2.0, 3.0])
query_times = np.array([0.5, 1.5, 2.5])
result = interpolate_transform_sequence(keyframes, times, query_times,
    rotation_method="squad",           # "slerp", "nlerp", or "squad" (smooth)
    translation_method="cubic_spline"  # "linear", "minimum_jerk", or "cubic_spline"
)

# Rotation interpolation (via class methods)
rot_mid = rot_start.slerp(rot_end, t=0.5)  # Spherical interpolation
rot_mid = rot_start.nlerp(rot_end, t=0.5)  # Faster, approximate

# Reusable spline (compute once, evaluate many times)
spline = compute_spline(waypoints, times)
positions = spline.evaluate(query_times)
velocities = spline.derivative(query_times, order=1)

# Low-level quaternion SLERP
q_mid = quaternion_slerp(q0, q1, t=0.5)
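Two of the interpolation schemes listed above have simple closed forms. The sketch below assumes that "minimum_jerk" refers to the usual quintic profile s(tau) = 10 tau^3 - 15 tau^4 + 6 tau^5 with tau = t / duration (here t is elapsed time, which may differ from how the package interprets t). It also implements SLERP for xyzw quaternions with the shortest-path sign flip. Both functions are hypothetical stand-ins, not the package's interpolate or quaternion_slerp.

```python
import numpy as np


def minimum_jerk(start, end, t: float, duration: float):
    """Quintic minimum-jerk blend; zero velocity and acceleration at both ends."""
    tau = np.clip(t / duration, 0.0, 1.0)
    s = 10 * tau**3 - 15 * tau**4 + 6 * tau**5
    return (1 - s) * np.asarray(start, dtype=float) + s * np.asarray(end, dtype=float)


def slerp_xyzw(q0, q1, t: float) -> np.ndarray:
    """Spherical linear interpolation between unit quaternions in xyzw order."""
    q0 = np.asarray(q0, dtype=float) / np.linalg.norm(q0)
    q1 = np.asarray(q1, dtype=float) / np.linalg.norm(q1)
    dot = np.dot(q0, q1)
    if dot < 0.0:  # q and -q encode the same rotation; flip to take the shorter arc
        q1, dot = -q1, -dot
    if dot > 0.9995:  # nearly parallel: fall back to normalized lerp (nlerp)
        q = (1 - t) * q0 + t * q1
        return q / np.linalg.norm(q)
    theta = np.arccos(dot)
    return (np.sin((1 - t) * theta) * q0 + np.sin(t * theta) * q1) / np.sin(theta)


pos_mid = minimum_jerk([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], t=1.0, duration=2.0)  # halfway in time
q_mid = slerp_xyzw([0, 0, 0, 1], [0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)], t=0.5)
```

At tau = 0.5 the quintic profile gives s = 0.5, so the midpoint in time is also the midpoint in space; the profile only differs from linear interpolation in how it accelerates and decelerates near the ends.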

Relative Transforms & Deltas

# Express tf in reference frame
tf_in_ref = tf.relative_to(reference_tf)  # = reference_tf.inverse() @ tf

# Apply incremental delta
tf_new = tf.apply_delta(delta_tf, in_world_frame=True)   # = delta_tf @ tf
tf_new = tf.apply_delta(delta_tf, in_world_frame=False)  # = tf @ delta_tf

# Same for Rotation
rot_rel = rot.relative_to(ref_rot)
rot_new = rot.apply_delta(delta_rot, in_body_frame=False)  # World frame
rot_new = rot.apply_delta(delta_rot, in_body_frame=True)   # Body frame
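The frame conventions in the comments above can be checked with plain 4x4 matrices. This NumPy sketch uses a hypothetical pose helper, not the library's API. It verifies that reference @ relative recovers the original pose, and that a world-frame delta (delta @ tf) and a body-frame delta (tf @ delta) generally give different results.

```python
import numpy as np


def pose(yaw: float, translation) -> np.ndarray:
    """Hypothetical helper: 4x4 transform with a rotation about z and a translation."""
    c, s = np.cos(yaw), np.sin(yaw)
    T = np.eye(4)
    T[:3, :3] = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    T[:3, 3] = translation
    return T


tf = pose(np.pi / 2, [1.0, 0.0, 0.0])
reference = pose(0.3, [0.0, 2.0, 0.0])
delta = pose(0.0, [0.1, 0.0, 0.0])  # move 0.1 along x

# Express tf in the reference frame: reference^-1 @ tf
tf_in_ref = np.linalg.inv(reference) @ tf

# World frame moves along the world x axis; body frame moves along tf's own x axis
tf_world = delta @ tf
tf_body = tf @ delta
```

Because tf is rotated 90 degrees about z, its own x axis points along world y, so the body-frame delta shifts the pose along world y instead of world x.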

Visualization (Rerun)

Visualize transforms, trajectories, and point clouds with Rerun. This requires the optional rerun-sdk package: pip install rerun-sdk

from uni_transform import Transform
from uni_transform.visualization import RerunVisualizer, log_transform, log_trajectory

# High-level API
viz = RerunVisualizer("my_app", spawn=True)  # spawn=True for GUI
viz.show_transform("robot/base", base_tf)
viz.show_trajectory("robot/path", trajectory)
viz.show_points("scene/cloud", points)

Rotation Representations

Name         Shape        Description
matrix       (..., 3, 3)  SO(3) rotation matrix
quat         (..., 4)     Quaternion (xyzw format)
euler        (..., 3)     Euler angles (default sequence: ZYX)
rotation_6d  (..., 6)     6D continuous rotation representation
rot_vec      (..., 3)     Rotation vector (axis × angle)

Here "..." denotes arbitrary batch dimensions, e.g. (B, T, 3, 3) for batched trajectories.
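The 6D representation stores two 3-vectors that are mapped back to a rotation matrix by Gram-Schmidt orthonormalization. This is what makes it continuous and convenient as a network output. The sketch below assumes the 6 numbers are the first two columns of the matrix; some libraries use rows instead, and the README does not say which convention uni-transform follows. rotation_6d_to_matrix_sketch is a hypothetical name.

```python
import numpy as np


def rotation_6d_to_matrix_sketch(d6: np.ndarray) -> np.ndarray:
    """Map (..., 6) vectors to (..., 3, 3) rotation matrices via Gram-Schmidt.

    Assumption: d6[..., :3] and d6[..., 3:] become the first and second columns.
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
    # Remove the b1 component from a2, then normalize
    a2_orth = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = a2_orth / np.linalg.norm(a2_orth, axis=-1, keepdims=True)
    b3 = np.cross(b1, b2)  # right-handed third axis
    return np.stack([b1, b2, b3], axis=-1)  # stack as columns


R_id = rotation_6d_to_matrix_sketch(np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

# Arbitrary (non-orthogonal) inputs still yield valid rotations
raw = np.random.default_rng(0).normal(size=(4, 6))
R_batch = rotation_6d_to_matrix_sketch(raw)
```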

Transform Class API

Factory Methods

Transform.identity(backend="numpy", extra_dims=0, translation_unit="m")
Transform.from_matrix(matrix_4x4, translation_unit="m")
Transform.from_rep(data, from_rep="quat", extra_dims=0, translation_unit="m")
Transform.from_pos_quat(position, quaternion, extra=None, translation_unit="m")
Transform.stack(transforms, axis=0)  # Requires matching units
Transform.convert(data, from_rep, to_rep, ...)  # Static conversion

Properties

tf.rotation          # (..., 3, 3) rotation matrix
tf.translation       # (..., 3) translation vector
tf.extra             # (..., K) extra state or None
tf.translation_unit  # TranslationUnit.METER or TranslationUnit.MILLIMETER
tf.backend           # "numpy" or "torch"
tf.batch_shape       # tuple of batch dimensions
tf.extra_dims        # number of extra dimensions (0 if extra is None)

Methods

tf.as_matrix()           # 4x4 homogeneous matrix
tf.to_rep("quat")        # Convert to [translation, rotation, extra]
tf.to_unit("mm")         # Convert translation unit
tf.inverse()             # Inverse transform
tf.transform_point(p)    # Apply to point(s)
tf.transform_vector(v)   # Apply rotation only
tf.apply_delta(delta)    # Compose with delta
tf.relative_to(ref)      # Express in reference frame
tf.clone()               # Deep copy
tf.to(device, dtype)     # Move to device (PyTorch)
tf.detach()              # Detach from graph (PyTorch)

Rotation Class API

Factory Methods

Rotation.identity(backend="numpy")
Rotation.from_matrix(matrix_3x3)
Rotation.from_rep(data, from_rep="quat", seq="ZYX", degrees=False)
Rotation.from_quat(quaternion)           # xyzw format
Rotation.from_euler(euler, seq="ZYX")
Rotation.from_rotvec(rotation_vector)
Rotation.from_rotation_6d(rot_6d)
Rotation.stack(rotations, axis=0)
Rotation.convert(data, from_rep, to_rep, ...)  # Static conversion

Properties

rot.matrix       # (..., 3, 3) rotation matrix
rot.backend      # "numpy" or "torch"
rot.batch_shape  # tuple of batch dimensions
rot.is_batched   # bool

Methods

rot.to_rep("euler", seq="ZYX")   # Convert to representation
rot.as_matrix()                   # (..., 3, 3)
rot.as_quat()                     # (..., 4) xyzw
rot.as_euler(seq="ZYX")           # (..., 3)
rot.as_rotvec()                   # (..., 3)
rot.as_rotation_6d()              # (..., 6)
rot.inverse()                     # Inverse rotation
rot.apply(vectors)                # Rotate vectors
rot.apply_delta(delta)            # Compose
rot.relative_to(ref)              # Relative rotation
rot.slerp(other, t)               # Spherical interpolation
rot.nlerp(other, t)               # Normalized linear interpolation
rot.clone()                       # Deep copy
rot.to(device, dtype)             # Move to device (PyTorch)
rot.detach()                      # Detach from graph (PyTorch)

Key Functions

# Conversions
quaternion_to_matrix, matrix_to_quaternion
euler_to_matrix, matrix_to_euler, matrix_to_euler_differentiable
rotvec_to_matrix, matrix_to_rotvec
quaternion_to_rotvec, rotvec_to_quaternion
rotation_6d_to_matrix, matrix_to_rotation_6d
convert_rotation, rotation_to_matrix, matrix_to_rotation

# Quaternion ops
quaternion_multiply, quaternion_apply, quaternion_inverse, quaternion_conjugate

# Distance (for loss functions)
geodesic_distance, translation_distance, transform_distance

# SE(3) Lie group
se3_log, se3_exp

# Interpolation (unified API)
interpolate, interpolate_sequence                    # Vector/scalar
interpolate_rotation, interpolate_rotation_sequence # Rotation
interpolate_transform, interpolate_transform_sequence # Transform
compute_spline, SplineCoefficients                   # Reusable spline

# Interpolation (low-level)
quaternion_slerp, quaternion_nlerp, quaternion_squad
minimum_jerk_interpolate, minimum_jerk_velocity, minimum_jerk_acceleration
cubic_spline_coefficients, cubic_spline_interpolate, cubic_spline_derivative

# Utilities
orthogonalize_rotation, xyz_rotation_6d_to_matrix

# Visualization (requires rerun-sdk)
from uni_transform.visualization import (
    init, connect, save,                    # Setup
    log_transform, log_trajectory,          # Core logging
    log_trajectory_animated,                # Animation
    log_transform_manager,                  # TransformManager graph
    log_points, log_text, log_scalar,       # Primitives
    set_time,                               # Timeline control
    RerunVisualizer,                        # High-level API
)

Conventions

  • Quaternion: xyzw format (matches SciPy/ROS)
  • Euler default: ZYX sequence (yaw-pitch-roll)
  • Composition: tf1 @ tf2 applies tf2 first, then tf1
  • Translation unit: Default is meters ("m"), also supports millimeters ("mm")
  • Extra: Preserved through inverse(), clone(), to_unit(); interpolated in interpolate_transform()
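Two of these conventions are easy to get wrong when exchanging data with other tools. The sketch below assumes that the ZYX (yaw-pitch-roll) sequence means R = Rz(yaw) @ Ry(pitch) @ Rx(roll), the usual intrinsic reading. It also shows how to reorder xyzw quaternions into the wxyz layout that some other libraries expect. All helper names are hypothetical.

```python
import numpy as np


def rx(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def ry(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rz(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def euler_zyx_to_matrix(yaw, pitch, roll):
    """Intrinsic Z, Y', X''. Read right to left, R @ v applies roll first (the rightmost factor)."""
    return rz(yaw) @ ry(pitch) @ rx(roll)


def xyzw_to_wxyz(q):
    """Move the scalar part from last to first position."""
    return np.roll(np.asarray(q, dtype=float), 1, axis=-1)


def wxyz_to_xyzw(q):
    """Move the scalar part from first to last position."""
    return np.roll(np.asarray(q, dtype=float), -1, axis=-1)


R = euler_zyx_to_matrix(np.pi / 2, 0.0, 0.0)  # pure yaw of 90 degrees
q_wxyz = xyzw_to_wxyz([0.0, 0.0, 0.707, 0.707])  # [0.707, 0.0, 0.0, 0.707]
```

The rightmost factor acting first is the same rule as "tf1 @ tf2 applies tf2 first".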

Module Structure

uni_transform/
├── __init__.py              # Public API exports
├── _core.py                 # Types, constants, backend utilities
├── rotation_conversions.py  # All rotation conversion functions
├── rotation.py              # Rotation class
├── transform.py             # Transform class
├── interpolation.py         # Interpolation functions
├── se3.py                   # SE(3) Lie group operations
├── metrics.py               # Distance functions
└── visualization.py         # Rerun visualization (optional)

License

MIT License



Download files


Source Distribution

uni_transform-0.2.1.tar.gz (53.8 kB)

Uploaded Source

Built Distribution


uni_transform-0.2.1-py3-none-any.whl (42.2 kB)

Uploaded Python 3

File details

Details for the file uni_transform-0.2.1.tar.gz.

File metadata

  • Download URL: uni_transform-0.2.1.tar.gz
  • Upload date:
  • Size: 53.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for uni_transform-0.2.1.tar.gz
Algorithm Hash digest
SHA256 2c98758c25de379d91af934e1c51dbf199c08cafc0a04935b5430519ad7fea0f
MD5 7490b9f05afd2d5b55b4297b9c92e153
BLAKE2b-256 f4e34fdee76451c03e1c3c53a677806281ac027e010de7f7d68e0c075921468c


File details

Details for the file uni_transform-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: uni_transform-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 42.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for uni_transform-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 c3cc590bc5e78469cd5c32f3c8c022fddf7a00d7549b65d353acd61eaac4c24d
MD5 8e8611a48ddaeed0f073c97aae3116eb
BLAKE2b-256 c68714d73e29d09bc7c38fc3b5b5f53e1eef2d070ee57f18a132c4b8bb200377

