JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al.
Vanilla Transformer
JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al. (https://arxiv.org/abs/1706.03762)
Installation
Install the package with pip:
pip install vanilla-transformer-jax
Usage
To use the full Transformer model (encoder and decoder together):
import jax.numpy as jnp
from jax import random
from vtransformer import Transformer  # imports the Transformer class
model = Transformer()  # hyperparameters can be tuned; otherwise the defaults from the paper are used
prng = random.PRNGKey(42)
example_input_src = random.randint(prng, (3, 4), minval=0, maxval=10000)
example_input_trg = random.randint(prng, (3, 5), minval=0, maxval=10000)
mask = jnp.array([1, 1, 1, 0, 0])
init = model.init(prng, example_input_src, example_input_trg, mask)  # initialize the model parameters
output = model.apply(init, example_input_src, example_input_trg, mask)  # run a forward pass
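The mask in the example above marks the first three target positions as real tokens and the last two as padding. As a rough illustration of where such a mask usually comes from and how it typically enters attention, here is a minimal sketch in plain JAX. It is not taken from this package's internals: the helper names make_padding_mask and masked_attention are hypothetical, and the choice of token id 0 as padding is an assumption made only for this example.

```python
import jax
import jax.numpy as jnp

PAD_ID = 0  # assumption: token id 0 means padding; the package may use another convention


def make_padding_mask(token_ids, pad_id=PAD_ID):
    """Return 1 where a position holds a real token and 0 where it holds padding."""
    return (token_ids != pad_id).astype(jnp.int32)


def masked_attention(q, k, v, key_mask):
    """Scaled dot-product attention that ignores masked-out key positions.

    q: (len_q, d_k), k: (len_k, d_k), v: (len_k, d_v), key_mask: (len_k,) of 0/1.
    """
    d_k = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d_k)  # (len_q, len_k)
    # A large negative score becomes a (near-)zero weight after the softmax.
    scores = jnp.where(key_mask[None, :] == 1, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights


# Three real tokens followed by two padding tokens.
tokens = jnp.array([17, 4, 92, 0, 0])
mask = make_padding_mask(tokens)

# Random queries, keys and values stand in for learned projections.
key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)
q = jax.random.normal(q_key, (5, 8))
k = jax.random.normal(k_key, (5, 8))
v = jax.random.normal(v_key, (5, 8))
out, weights = masked_attention(q, k, v, mask)
```

In this sketch the attention weights on the two padded positions are driven to zero, so the output for every query is a weighted average of the three real tokens only.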
To use the Encoder and Decoder separately:
encoding = model.encoder(init, example_input_src)  # using only the encoder
decoding = model.decoder(init, example_input_trg, encoding, mask)  # using only the decoder
Contributing
This library is not perfect and can be improved in several respects.
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Please make sure to update tests as appropriate.
License