PyTorch implementation of Rotary Spatial Embeddings
RoSE: N-dimensional Rotary Spatial Embeddings
Original implementation of Rotary Spatial Embeddings (in PyTorch)
Rotary Spatial Embeddings (RoSE) extends the original 1D Rotary Position Embeddings (RoPE) and their 2D variants so that the embeddings encode spatial information as N-dimensional real-world coordinates. This is particularly useful for tasks that require understanding spatial relationships across different scales, such as microscopy.
Explanation
1 Relative phase in 1-D RoPE
If you write the 1-D RoPE positional factor for token $t$ as a per-token complex phase
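$$\phi(t)=e^{\,i\,t\theta},\qquad t\in\mathbb Z,$$
and attach that same phase to both the query $q_t$ and the key $k_t$,
$$\tilde q_t = q_t\,\phi(t),\qquad \tilde k_t = k_t\,\phi(t),$$
then, with $^*$ denoting complex conjugation, the Hermitian product of a rotated query and key inside attention becomes
$$\tilde q_n\,\tilde k_m^{*} \;=\; q_n\,k_m^{*}\,\underbrace{\phi(n)\,\phi(m)^{*}}_{=\,e^{\,i\,(n-m)\theta}} .$$
The attention logit is the real part of this product, so it depends on the positions $n$ and $m$ only through their offset $n-m$.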
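A quick numerical check of this identity, written with Python's standard `cmath` module rather than the library. It assumes the convention of applying the same phase $\phi(t)$ to query and key and taking the real part of their Hermitian product; $q_n$ and $k_m$ are treated as single complex numbers, and the frequency `theta` is an arbitrary value chosen for illustration.

```python
import cmath


def phase(t: int, theta: float) -> complex:
    """1-D RoPE phase phi(t) = exp(i * t * theta)."""
    return cmath.exp(1j * t * theta)


def rope_logit(q: complex, k: complex, n: int, m: int, theta: float) -> float:
    """Real part of the Hermitian product of the rotated query and key."""
    q_rot = q * phase(n, theta)  # q~_n = q_n * phi(n)
    k_rot = k * phase(m, theta)  # k~_m = k_m * phi(m)
    return (q_rot * k_rot.conjugate()).real


theta = 0.3  # illustrative frequency (assumption)
q, k = 1.0 + 2.0j, -0.5 + 1.5j

direct = rope_logit(q, k, n=7, m=3, theta=theta)
# Same offset n - m = 4, but different absolute positions
shifted = rope_logit(q, k, n=107, m=103, theta=theta)
# Closed form: Re[q k* e^{i (n - m) theta}]
closed = (q * k.conjugate() * cmath.exp(1j * 4 * theta)).real

print(abs(direct - shifted) < 1e-9, abs(direct - closed) < 1e-9)  # True True
```

Shifting both tokens by 100 positions leaves the logit unchanged, which is the relative-position property that RoSE carries over to N dimensions.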
⸻
2 Extending to N dimensions
Give every token a coordinate vector $\mathbf{p}=(x,y,z,\dots)\in\mathbb R^{N}$ and define its phase as
$$\phi(\mathbf{p}) \;=\; e^{\,i\,\langle\mathbf{p},\,\boldsymbol\theta\rangle},\qquad \langle\mathbf{p},\boldsymbol\theta\rangle=\sum_{a=1}^{N} p_a\,\theta_a ,$$
where $\boldsymbol\theta=(\theta_1,\dots,\theta_N)$ holds one frequency per spatial axis.
Then
$$\phi(\mathbf{p}_n)\,\phi(\mathbf{p}_m)^{*} \;=\; e^{\,i\,\langle\mathbf{p}_n-\mathbf{p}_m,\;\boldsymbol\theta\rangle},$$
which is the N-D generalisation of the 1-D factor $e^{\,i\,(n-m)\theta}$. You still get
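$$A_{nm}\;=\;\mathrm{Re}\bigl[q_n\,k_m^{*}\;e^{\,i\,\langle\mathbf{p}_n-\mathbf{p}_m,\;\boldsymbol\theta\rangle}\bigr]$$
for the attention logit between tokens $n$ and $m$, which depends only on their displacement $\mathbf{p}_n-\mathbf{p}_m$, while the total cost of applying the encoding stays $O(LD)$ for $L$ tokens of embedding dimension $D$ (with the number of spatial dimensions $N$ fixed).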
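The sketch below applies the N-D phase to real-valued query and key vectors by viewing consecutive channel pairs as complex numbers, the usual way RoPE is realised on real tensors. It is a pure-Python illustration, not the `RoSELayer` implementation: the frequency matrix `thetas` (one N-dimensional $\boldsymbol\theta$ per channel pair) is filled with arbitrary random values as an assumption. It checks that translating both tokens by the same vector leaves $A_{nm}$ unchanged.

```python
import cmath
import math
import random


def rotate(x: list[float], p: list[float], thetas: list[list[float]]) -> list[complex]:
    """Rotate a real vector x (length D) by exp(i <p, theta_j>) for each channel pair j."""
    assert len(x) == 2 * len(thetas)
    out = []
    for j, theta in enumerate(thetas):
        z = complex(x[2 * j], x[2 * j + 1])  # channel pair -> complex number
        angle = sum(pa * ta for pa, ta in zip(p, theta))  # <p, theta_j>
        out.append(z * cmath.exp(1j * angle))
    return out


def logit(q, k, p_n, p_m, thetas) -> float:
    """A_nm = sum_j Re[q~_nj * conj(k~_mj)], i.e. the real dot product of the rotated vectors."""
    qr, kr = rotate(q, p_n, thetas), rotate(k, p_m, thetas)
    return sum((a * b.conjugate()).real for a, b in zip(qr, kr))


random.seed(0)
D, N = 8, 3  # embedding dim, spatial dims
# Assumption: arbitrary per-pair frequency vectors, not RoSE's actual initialisation
thetas = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(D // 2)]
q = [random.gauss(0, 1) for _ in range(D)]
k = [random.gauss(0, 1) for _ in range(D)]

p_n, p_m = [1.0, 2.0, 3.0], [0.5, -1.0, 2.0]
shift = [10.0, -4.0, 7.5]  # translate both tokens by the same vector
a1 = logit(q, k, p_n, p_m, thetas)
a2 = logit(q, k, [a + s for a, s in zip(p_n, shift)],
           [a + s for a, s in zip(p_m, shift)], thetas)
print(math.isclose(a1, a2, abs_tol=1e-9))  # True
```

Each token is rotated independently (linear in $L$), yet the pairwise logit only sees $\mathbf{p}_n-\mathbf{p}_m$, which is where the $O(LD)$ encoding cost comes from.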
3 Embedding real-world coordinates
In many applications, such as microscopy or 3D point clouds, the coordinates are not just indices but real-world positions that carry useful spatial information. RoSE injects these coordinates directly into the rotary embeddings: the coordinate vectors are simply multiplied by the coordinate spacing (i.e. the voxel size) along each axis before the rotary embedding is applied.
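A minimal sketch of the scaling step in plain Python: grid indices are multiplied by the per-axis voxel size to obtain physical coordinates, which then feed the phase $\langle\mathbf{p},\boldsymbol\theta\rangle$. The helper names `to_physical` and `phase` and the frequency vector are hypothetical, chosen for illustration. The check shows why the scaling matters: the same physical location sampled on a fine grid and on a 2× coarser grid yields the same phase, whereas raw indices do not.

```python
import cmath


def to_physical(index: tuple[int, ...], voxel_size: tuple[float, ...]) -> tuple[float, ...]:
    """Hypothetical helper: grid index -> real-world coordinate (index * spacing per axis)."""
    return tuple(i * s for i, s in zip(index, voxel_size))


def phase(p: tuple[float, ...], theta: tuple[float, ...]) -> complex:
    """phi(p) = exp(i <p, theta>)."""
    return cmath.exp(1j * sum(pa * ta for pa, ta in zip(p, theta)))


theta = (0.7, -0.2, 0.4)  # illustrative frequencies (assumption)

# Anisotropic voxels: 4 units along z, 1 unit along y and x
p_fine = to_physical((6, 10, 12), (4.0, 1.0, 1.0))
# The same location on a grid downsampled 2x in y and x
p_coarse = to_physical((6, 5, 6), (4.0, 2.0, 2.0))

print(p_fine, p_coarse)  # (24.0, 10.0, 12.0) (24.0, 10.0, 12.0)
print(cmath.isclose(phase(p_fine, theta), phase(p_coarse, theta)))  # True
# Using raw indices instead, the two grids disagree
print(cmath.isclose(phase((6, 10, 12), theta), phase((6, 5, 6), theta)))  # False
```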
Installation
From PyPI
pip install rotary-spatial-embeddings
From source
pip install git+https://github.com/rhoadesScholar/RoSE.git
Usage
import torch
from RoSE import RoSELayer, RoSEMultiheadSelfAttention
# Basic RoSE layer for applying rotary spatial embeddings to q and k
layer = RoSELayer(dim=128, num_heads=8, spatial_dims=3, learnable=True)
batch_size, seq_len = 2, 1000
q = torch.randn(batch_size, seq_len, 128)
k = torch.randn(batch_size, seq_len, 128)
# Define spatial grid properties
grid_shape = (10, 10, 10) # 3D grid dimensions
voxel_size = (1.0, 1.0, 1.0) # Physical size of each voxel
# Apply rotary spatial embeddings
q_rot, k_rot = layer(q, k, grid_shape, voxel_size)
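In the example above, `seq_len` (1000) matches the number of voxels in `grid_shape` (10 × 10 × 10). The sketch below shows one way such a grid could be turned into per-token physical coordinates. `grid_coordinates` is a hypothetical helper, not part of the RoSE API, and the row-major (last axis fastest) flattening order is an assumption rather than something taken from the library source.

```python
from itertools import product
from math import prod


def grid_coordinates(grid_shape, voxel_size):
    """Hypothetical helper: physical coordinate of every voxel, flattened in row-major order."""
    return [tuple(i * s for i, s in zip(idx, voxel_size))
            for idx in product(*(range(n) for n in grid_shape))]


grid_shape = (10, 10, 10)
voxel_size = (1.0, 1.0, 1.0)
coords = grid_coordinates(grid_shape, voxel_size)

print(len(coords) == prod(grid_shape))  # True: one coordinate per token
print(coords[0], coords[1], coords[-1])  # (0.0, 0.0, 0.0) (0.0, 0.0, 1.0) (9.0, 9.0, 9.0)
```

If the token order of `q` and `k` differs from the order in which the layer enumerates grid positions, each token would be paired with the wrong coordinate, so it is worth confirming the flattening convention against the library before relying on this sketch.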
License
BSD 3-Clause License. See LICENSE for details.
Download files
File details
Details for the file rotary_spatial_embeddings-2025.7.31.528.tar.gz.
File metadata
- Download URL: rotary_spatial_embeddings-2025.7.31.528.tar.gz
- Upload date:
- Size: 10.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 77993a66653a95a2ba80cc18f671faa3f066880e739e3906454f9cc2cc290937 |
| MD5 | 4dab4d2ebb2d56174d266a9415d78fd2 |
| BLAKE2b-256 | 7e8eef1cd2c7a76bdff502d6461bd5f207a53e1175e01bb7ee14ff05ad1525b3 |
File details
Details for the file rotary_spatial_embeddings-2025.7.31.528-py3-none-any.whl.
File metadata
- Download URL: rotary_spatial_embeddings-2025.7.31.528-py3-none-any.whl
- Upload date:
- Size: 7.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1b2f4be1604b583cfddf3736b46d6bbb483f8553fd7e931afe9bcee53953a077 |
| MD5 | b94a876ce4a7358177f69b712cf06b33 |
| BLAKE2b-256 | c9ed29ff0e8c16222a0e1b6d62c10d98a45f2c394bff0638addfbe1186e853df |