

Rotary Embeddings - Tensorflow

A standalone library for adding rotary embeddings to transformers in TensorFlow, following their success as a relative positional encoding. Specifically, it makes rotating information into any axis of a tensor easy and efficient, whether the positions are fixed or learned. This library will give you state-of-the-art results for positional embedding, at little cost.

My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.

Note

A PyTorch implementation is available in this repository.

This version was written by porting that PyTorch implementation to TensorFlow.

The three functions rearrange, irearrange and repeat were written because the einops library is incompatible with TensorFlow 2.x.
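For intuition only: an einops-style pairwise split such as rearrange(x, '... (d r) -> ... d r', r = 2) can be reproduced with plain tf.reshape. A minimal sketch (the function name below is hypothetical, not this library's API):

import tensorflow as tf

# hypothetical stand-in for einops' rearrange(x, '... (d r) -> ... d r', r = 2)
def split_into_pairs(x):
    *lead, d = x.shape
    return tf.reshape(x, (*lead, d // 2, 2))

x = tf.random.normal((1, 1024, 64))
print(split_into_pairs(x).shape) # (1, 1024, 32, 2)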

Install

$ pip install rotary-embedding-tensorflow

Usage

import tensorflow as tf
from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

pos_emb = RotaryEmbedding(dim = 32)

# generate the rotations

freqs = pos_emb(tf.range(1024), cache_key = 1024) # cache with a key that is the sequence length, so that it does not need to recompute

# mock queries and keys

q = tf.random.normal((1, 1024, 64)) # queries - (batch, seq len, dimension of head)
k = tf.random.normal((1, 1024, 64)) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

freqs = freqs[None, ...] # add a leading batch dimension
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

# then do your attention with your queries (q) and keys (k)

If you do all the steps above correctly, you should see a dramatic improvement during training.
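As a minimal sketch of that final attention step (plain scaled dot-product attention, not part of this library), reusing the rotated q and k from above with mock values v:

# mock values
v = tf.random.normal((1, 1024, 64))

# plain scaled dot-product attention (illustrative only)
scale = tf.math.rsqrt(tf.cast(tf.shape(q)[-1], q.dtype))
sim = tf.einsum('b i d, b j d -> b i j', q, k) * scale
attn = tf.nn.softmax(sim, axis = -1)
out = tf.einsum('b i j, b j d -> b i d', attn, v) # (1, 1024, 64)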

Axial Rotary Embeddings

For easy use of 2D axial relative positional embeddings, e.g. for vision transformers:

import tensorflow as tf
from rotary_embedding_tensorflow import apply_rotary_emb, RotaryEmbedding, broadcat

pos_emb = RotaryEmbedding(
    dim = 32,
    freqs_for = 'pixel'
)

# queries and keys for frequencies to be rotated into

q = tf.random.normal((1, 256, 256, 64))
k = tf.random.normal((1, 256, 256, 64))

# get frequencies for each axis
# -1 to 1 has been shown to be a good choice for images and audio

freqs_h = pos_emb(tf.linspace(-1, 1, num = 256), cache_key = 256)
freqs_w = pos_emb(tf.linspace(-1, 1, num = 256), cache_key = 256)

# concat the frequencies along each axis
# the broadcat function makes this easy without a bunch of expands

freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1)

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)
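The rotated tensors still carry the 2D spatial layout; a common next step (a sketch under the shapes above, not part of this library) is to flatten the spatial axes into a single sequence axis before running attention:

# flatten (batch, height, width, dim) -> (batch, height * width, dim) for standard attention
b, h, w, d = q.shape
q = tf.reshape(q, (b, h * w, d))
k = tf.reshape(k, (b, h * w, d))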

Learned Rotations

For injecting learned rotations into a network. Experiments pending.

Update: doesn't seem to do anything -_-, will keep trying...

import tensorflow as tf
from tensorflow.keras import layers
from rotary_embedding_tensorflow import apply_learned_rotations

x = tf.random.normal((1, 1024, 512))

# you can only rotate in (dim // 2) values
# ex. for 512, you can only rotate in 256 values

# say you have two sets of learned rotations of 128 values each

rots1 = layers.Dense(128)(x)
rots2 = layers.Dense(128)(x)

# you rotate in 256 (128 x 2) at first

x = apply_learned_rotations(rots1, x, start_index = 0)

# then you start at index 256 and rotate in the last (128 x 2)

x = apply_learned_rotations(rots2, x, start_index = 256)

# you could also concat the rotations together and pass it in all at once

rots = tf.concat((rots1, rots2), axis = -1)

x = apply_learned_rotations(rots, x)
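Why only dim // 2 values: each rotation angle acts on one pair of feature dimensions, rotating that pair in a 2D plane, so d features admit at most d // 2 independent angles. A conceptual sketch of such a pairwise rotation (assuming interleaved pairs; the library's own pairing convention may differ):

# conceptual sketch: rotate consecutive feature pairs (x1, x2) by the given angles
def rotate_pairs(rots, x):
    x1, x2 = x[..., 0::2], x[..., 1::2]   # (..., d // 2) each
    cos, sin = tf.cos(rots), tf.sin(rots) # rots has shape (..., d // 2)
    r1 = x1 * cos - x2 * sin
    r2 = x1 * sin + x2 * cos
    return tf.reshape(tf.stack((r1, r2), axis = -1), tf.shape(x)) # re-interleave the pairs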

Citations

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

@misc{rotary-embedding-torch,
    title   = {Rotary Embeddings - Pytorch}, 
    author  = {Phil Wang (lucidrains)},
    year    = {2021},
    url  = {https://github.com/lucidrains/rotary-embedding-torch},
    publisher = {Github},
}

