Bidirectional Cross Attention
JAX implementation of lucidrains/bidirectional-cross-attention.
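The core idea can be illustrated with a minimal single-head sketch in plain JAX. It assumes the approach of the upstream repository: each sequence gets one shared query/key projection, a single similarity matrix is computed between the two sequences, and that matrix is softmaxed along each axis so both sequences are updated in one step. `init_params`, `bidirectional_cross_attention` and the weight names are hypothetical illustrations, not this package's API. The package's module is multi-head (`heads`, `dim_head`) and may differ in other details, so treat this as a sketch rather than a reference implementation.

```python
import jax
import jax.numpy as jnp


def init_params(key, dim, context_dim, dim_head):
    """Random single-head weights (hypothetical layout, not the package's)."""
    ks = jax.random.split(key, 6)

    def dense(k, n_in, n_out):
        return jax.random.normal(k, (n_in, n_out)) * n_in ** -0.5

    return {
        "qk": dense(ks[0], dim, dim_head),               # shared query/key for x
        "v": dense(ks[1], dim, dim_head),
        "ctx_qk": dense(ks[2], context_dim, dim_head),   # shared query/key for context
        "ctx_v": dense(ks[3], context_dim, dim_head),
        "out": dense(ks[4], dim_head, dim),              # project back to dim
        "ctx_out": dense(ks[5], dim_head, context_dim),  # project back to context_dim
    }


def bidirectional_cross_attention(params, x, context, mask=None, context_mask=None):
    """x: (b, i, dim), context: (b, j, context_dim); masks are bool, True = keep."""
    qk, v = x @ params["qk"], x @ params["v"]                              # (b, i, d)
    ctx_qk, ctx_v = context @ params["ctx_qk"], context @ params["ctx_v"]  # (b, j, d)

    # One similarity matrix, shared by both attention directions.
    sim = jnp.einsum("bid,bjd->bij", qk, ctx_qk) * qk.shape[-1] ** -0.5

    if mask is not None or context_mask is not None:
        b, i, j = sim.shape
        m = jnp.ones((b, i), dtype=bool) if mask is None else mask
        cm = jnp.ones((b, j), dtype=bool) if context_mask is None else context_mask
        keep = m[:, :, None] & cm[:, None, :]  # (b, i, j): both positions valid
        sim = jnp.where(keep, sim, -jnp.finfo(sim.dtype).max)

    attn = jax.nn.softmax(sim, axis=-1)      # each x token attends over context tokens
    ctx_attn = jax.nn.softmax(sim, axis=-2)  # each context token attends over x tokens

    out = jnp.einsum("bij,bjd->bid", attn, ctx_v) @ params["out"]
    ctx_out = jnp.einsum("bij,bid->bjd", ctx_attn, v) @ params["ctx_out"]
    return out, ctx_out
```

Computing `sim` once and reusing it for both softmaxes is what makes the layer bidirectional: the same token-pair scores decide how `x` reads from `context` and how `context` reads from `x`, and each output is projected back to its own input width, so shapes are preserved.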
Installation
pip install bidirectional-cross-attention-jax
Usage
import jax
import jax.numpy as jnp
from bidirectional_cross_attention_jax import BidirectionalCrossAttention

# separate PRNG keys for parameter init and for each random input
key = jax.random.PRNGKey(0)
init_key, video_key, audio_key = jax.random.split(key, 3)

video = jax.random.normal(video_key, (1, 4096, 512))  # (batch, video tokens, dim)
audio = jax.random.normal(audio_key, (1, 8192, 386))  # (batch, audio tokens, context_dim)

video_mask = jnp.ones((1, 4096), dtype=jnp.bool_)
audio_mask = jnp.ones((1, 8192), dtype=jnp.bool_)

joint_cross_attn = BidirectionalCrossAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    context_dim = 386
)

params = joint_cross_attn.init(init_key, video, audio)

video_out, audio_out = joint_cross_attn.apply(
    params,
    video,
    audio,
    mask = video_mask,
    context_mask = audio_mask
)

# attended output should have the same shape as input
assert video_out.shape == video.shape
assert audio_out.shape == audio.shape
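The usage example passes all-True masks. For a padded batch, masks are usually built from per-example sequence lengths. The sketch below assumes that `True` marks a real (non-padded) position, which is what the all-ones example suggests but is worth confirming in the package source; `lengths_to_mask` is a hypothetical helper, not part of the package.

```python
import jax.numpy as jnp


def lengths_to_mask(lengths, max_len):
    """Hypothetical helper: bool (batch, max_len) mask, True = real position."""
    positions = jnp.arange(max_len)[None, :]            # (1, max_len)
    return positions < jnp.asarray(lengths)[:, None]    # (batch, max_len)


# Hypothetical batch of two clips padded to 4096 video / 8192 audio tokens
video_mask = lengths_to_mask([4096, 3000], 4096)
audio_mask = lengths_to_mask([8192, 6000], 8192)
```

These masks can then be passed as `mask` and `context_mask` in the `apply` call above, with `video` and `audio` given a matching batch size of 2.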
Citations
@article{Hiller2024PerceivingLS,
title = {Perceiving Longer Sequences With Bi-Directional Cross-Attention Transformers},
author = {Markus Hiller and Krista A. Ehinger and Tom Drummond},
journal = {ArXiv},
year = {2024},
volume = {abs/2402.12138},
url = {https://api.semanticscholar.org/CorpusID:267751060}
}
Hashes for bidirectional_cross_attention_jax-0.0.2.tar.gz (source distribution)
Algorithm | Hash digest
---|---
SHA256 | 0314a3dbbd8392cd23fb3d49333cde432d6e9129134dc0c6190e83f6655450fa
MD5 | 6ee7be29074640a990e8698115b0ad44
BLAKE2b-256 | ef8795a567e185e5e2d1a8a394beb56103921243b7eee0dca4c4894747a4d977
Hashes for bidirectional_cross_attention_jax-0.0.2-py3-none-any.whl (built distribution)
Algorithm | Hash digest
---|---
SHA256 | 682f0d321a3924e461c7fa22ac05cb6c564b386c65d7ce9261c91efdf5c9a668
MD5 | b51e1fa942fdb6b6ac3c59f6e180c8de
BLAKE2b-256 | 1e98e934230779fce5e5bf62eb5adb0555699700425299b66c921e4a9bcae350