Bidirectional Cross Attention

JAX implementation of lucidrains/bidirectional-cross-attention.

Installation

pip install bidirectional-cross-attention-jax

Usage

import jax
import jax.numpy as jnp
from bidirectional_cross_attention_jax import BidirectionalCrossAttention

# use a separate PRNG key for each random draw so the inputs are independent
key = jax.random.PRNGKey(0)
init_key, video_key, audio_key = jax.random.split(key, 3)

video = jax.random.normal(video_key, (1, 4096, 512))
audio = jax.random.normal(audio_key, (1, 8192, 386))

# boolean masks, True marks a valid (non-padded) position
video_mask = jnp.ones((1, 4096), dtype=jnp.bool_)
audio_mask = jnp.ones((1, 8192), dtype=jnp.bool_)

joint_cross_attn = BidirectionalCrossAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    context_dim = 386
)

# initialize the parameters, then apply the module
params = joint_cross_attn.init(init_key, video, audio)
video_out, audio_out = joint_cross_attn.apply(
    params,
    video,
    audio,
    mask = video_mask,
    context_mask = audio_mask
)

# attended outputs should have the same shapes as the inputs

assert video_out.shape == video.shape
assert audio_out.shape == audio.shape
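
For intuition, here is a minimal single-head sketch of the shared-similarity idea behind bidirectional cross attention: one similarity matrix is computed between the two sequences, then softmaxed along each axis so that each side attends to the other. The function name `bidirectional_attend` and its arguments are hypothetical and not part of this package's API. The sketch assumes the inputs are already projected to a common dimension, and it leaves out multiple heads, output projections, dropout, and anything else the actual module may do.

```python
import jax
import jax.numpy as jnp


def bidirectional_attend(qk, context_qk, v, context_v, mask=None, context_mask=None):
    """Illustrative single-head bidirectional cross attention (hypothetical helper).

    qk:         (batch, n, d)    shared query/key features of the first sequence
    context_qk: (batch, m, d)    shared query/key features of the second sequence
    v:          (batch, n, d_v)  values of the first sequence
    context_v:  (batch, m, d_v)  values of the second sequence
    mask, context_mask: optional boolean arrays of shape (batch, n) and (batch, m)
    """
    scale = qk.shape[-1] ** -0.5

    # one similarity matrix, shared by both attention directions
    sim = jnp.einsum("bid,bjd->bij", qk, context_qk) * scale

    if mask is not None or context_mask is not None:
        if mask is None:
            mask = jnp.ones(qk.shape[:2], dtype=jnp.bool_)
        if context_mask is None:
            context_mask = jnp.ones(context_qk.shape[:2], dtype=jnp.bool_)
        # a pair (i, j) is valid only if both positions are valid
        pair_mask = mask[:, :, None] & context_mask[:, None, :]
        sim = jnp.where(pair_mask, sim, jnp.finfo(sim.dtype).min)

    # softmax over the context axis: the first sequence attends to the second
    attn = jax.nn.softmax(sim, axis=-1)
    # softmax over the sequence axis: the second sequence attends to the first
    context_attn = jax.nn.softmax(sim, axis=-2)

    out = jnp.einsum("bij,bjd->bid", attn, context_v)
    context_out = jnp.einsum("bij,bid->bjd", context_attn, v)
    return out, context_out


# small toy inputs, assumed to be already projected to a shared dimension
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
a_qk = jax.random.normal(k1, (2, 5, 16))
b_qk = jax.random.normal(k2, (2, 7, 16))
a_v = jax.random.normal(k3, (2, 5, 16))
b_v = jax.random.normal(k4, (2, 7, 16))

a_out, b_out = bidirectional_attend(a_qk, b_qk, a_v, b_v)
```

Because both directions reuse the same similarity matrix, the pairwise scores are computed only once, and swapping the roles of the two sequences simply swaps the two outputs.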

Citations

@article{Hiller2024PerceivingLS,
    title   = {Perceiving Longer Sequences With Bi-Directional Cross-Attention Transformers},
    author  = {Markus Hiller and Krista A. Ehinger and Tom Drummond},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.12138},
    url     = {https://api.semanticscholar.org/CorpusID:267751060}
}
