
JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al.


Vanilla Transformer


JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al. (https://arxiv.org/abs/1706.03762)

Installation

Install the package with pip:

pip install vanilla-transformer-jax

Usage

To use the entire Transformer model (encoder and decoder), instantiate it, initialize its parameters, and apply it as follows:

import jax.numpy as jnp
from jax import random
from vtransformer import Transformer  # imports the Transformer class

model = Transformer()  # hyperparameters can be tuned; otherwise the defaults from the paper are used

prng = random.PRNGKey(42)

# random token IDs: a batch of 3 source sequences (length 4) and 3 target sequences (length 5)
example_input_src = random.randint(prng, (3, 4), minval=0, maxval=10000)
example_input_trg = random.randint(prng, (3, 5), minval=0, maxval=10000)
mask = jnp.array([1, 1, 1, 0, 0])  # one entry per target position

init = model.init(prng, example_input_src, example_input_trg, mask)  # initialize the model parameters

output = model.apply(init, example_input_src, example_input_trg, mask)  # run the forward pass
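
The usage above passes a mask of 0s and 1s alongside the target tokens. How vtransformer consumes that mask internally is not documented here. The sketch below only illustrates, assuming 1 marks a real token and 0 marks padding, how such a mask typically enters the scaled dot-product attention described in the paper. The helpers `padding_mask` and `masked_attention`, and the pad ID of 0, are hypothetical and not part of the package.

```python
import jax
import jax.numpy as jnp


def padding_mask(token_ids, pad_id=0):
    """Return 1 for real tokens and 0 for padding (assumed convention, pad_id is hypothetical)."""
    return (token_ids != pad_id).astype(jnp.int32)


def masked_attention(q, k, v, mask):
    """Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v, ignoring masked key positions."""
    d_k = q.shape[-1]
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(d_k)  # shape (len_q, len_k)
    # Push scores of masked key positions to a large negative value so softmax gives them ~0 weight.
    scores = jnp.where(mask[..., None, :] == 1, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights


key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)
q = jax.random.normal(q_key, (5, 8))  # 5 positions, feature size 8 (illustrative sizes)
k = jax.random.normal(k_key, (5, 8))
v = jax.random.normal(v_key, (5, 8))

mask = padding_mask(jnp.array([17, 42, 7, 0, 0]))  # same pattern as [1, 1, 1, 0, 0]
out, weights = masked_attention(q, k, v, mask)
```

Each row of `weights` sums to 1, and the columns for the last two (masked) positions are effectively zero, so padded tokens do not contribute to `out`.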

To use the encoder and decoder separately, you can do so in the following way:

encoding = model.encoder(init, example_input_src)  # using only the encoder
decoding = model.decoder(init, example_input_trg, encoding, mask)  # using only the decoder

Contributing

This library is not perfect and can be improved in several respects.

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

License

MIT

