JAX/Flax implementation of 'Attention is All You Need' by Vaswani et al.
Vanilla Transformer
JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al. (https://arxiv.org/abs/1706.03762)
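For readers new to the paper, the core operation the Transformer is built on is scaled dot-product attention, softmax(QKᵀ/√d_k)V. The sketch below is a minimal, standalone JAX version written for illustration only. It is not taken from `vtransformer`, and its boolean mask convention (True = position may be attended to) is an assumption.

```python
import jax
import jax.numpy as jnp


def scaled_dot_product_attention(q, k, v, mask=None):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V (Vaswani et al., 2017).

    q: (..., len_q, d_k), k: (..., len_k, d_k), v: (..., len_k, d_v)
    mask: optional boolean array broadcastable to (..., len_q, len_k);
          True marks allowed positions (an assumed convention for this sketch).
    Returns the attended values (..., len_q, d_v) and the attention weights.
    """
    d_k = q.shape[-1]
    # Similarity of every query with every key, scaled by sqrt(d_k).
    scores = jnp.einsum("...qd,...kd->...qk", q, k) / jnp.sqrt(d_k)
    if mask is not None:
        # Blocked positions get the most negative float, so softmax gives them ~0 weight.
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("...qk,...kv->...qv", weights, v), weights


if __name__ == "__main__":
    key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(key_q, (2, 4, 8))
    k = jax.random.normal(key_k, (2, 6, 8))
    v = jax.random.normal(key_v, (2, 6, 16))
    out, w = scaled_dot_product_attention(q, k, v)
    print(out.shape, w.shape)  # (2, 4, 16) (2, 4, 6)
```

Multi-head attention, as described in the paper, runs this operation in parallel on several learned projections of Q, K and V and concatenates the results.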
Installation
Install the package with pip:
pip install vanilla-transformer-jax
Usage
To use the full Transformer model (encoder and decoder):
import jax.numpy as jnp
from jax import random
from vtransformer import Transformer  # imports the Transformer class
model = Transformer()  # hyperparameters can be passed as arguments; otherwise the defaults from the paper are used
prng = random.PRNGKey(42)
key_init, key_src, key_trg = random.split(prng, 3)  # separate keys for parameters and each random input
example_input_src = random.randint(key_src, (3, 4), minval=0, maxval=10000)  # (batch=3, src_len=4) token ids
example_input_trg = random.randint(key_trg, (3, 5), minval=0, maxval=10000)  # (batch=3, trg_len=5) token ids
mask = jnp.array([1, 1, 1, 0, 0])  # one entry per target position (trg_len = 5)
init = model.init(key_init, example_input_src, example_input_trg, mask)  # initialize the model parameters
output = model.apply(init, example_input_src, example_input_trg, mask)  # run the forward pass
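The README does not spell out the mask convention. As an illustration only, the sketch below builds two masks commonly used with Transformers: a padding mask (assuming padding token id 0 and 1 = keep) and the look-ahead (causal) mask the paper applies in decoder self-attention. The helper names are hypothetical, and whether `vtransformer` expects these shapes or this 1/0 convention is an assumption.

```python
import jax.numpy as jnp


def padding_mask(tokens, pad_id=0):
    """1 where tokens are real, 0 where they equal the (assumed) padding id.

    tokens: (batch, seq_len) integer ids -> (batch, 1, 1, seq_len), so the
    mask broadcasts over attention heads and query positions.
    """
    return (tokens != pad_id)[:, None, None, :].astype(jnp.int32)


def causal_mask(seq_len):
    """Lower-triangular (1, 1, seq_len, seq_len) mask: position i sees j <= i."""
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.int32))[None, None]


def decoder_self_attention_mask(trg_tokens, pad_id=0):
    """Combine the padding and look-ahead masks for decoder self-attention."""
    return padding_mask(trg_tokens, pad_id) * causal_mask(trg_tokens.shape[-1])


if __name__ == "__main__":
    trg = jnp.array([[5, 7, 9, 0, 0]])  # last two positions are padding (pad_id=0)
    print(padding_mask(trg)[0, 0, 0])  # [1 1 1 0 0]
    print(decoder_self_attention_mask(trg).shape)  # (1, 1, 5, 5)
```

Multiplying the two masks means a target position is visible only if it is both a real token and not in the future.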
To use the encoder and decoder separately:
encoding = model.encoder(init, example_input_src)  # encoder only
decoding = model.decoder(init, example_input_trg, encoding, mask)  # decoder only, reusing the encoder output
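A common reason to call the encoder and decoder separately is autoregressive generation: the source sequence is encoded once, and the decoder is called repeatedly as the target grows. The sketch below shows greedy decoding over two hypothetical callables, `encode_fn` and `decode_fn` (for example, thin wrappers around the two calls above). It assumes the decoder returns logits of shape (batch, trg_len, vocab) and accepts a 2D causal mask; neither is documented here.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def greedy_decode(
    encode_fn: Callable[[jnp.ndarray], jnp.ndarray],
    decode_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray],
    src: jnp.ndarray,
    bos_id: int,
    max_len: int,
) -> jnp.ndarray:
    """Encode `src` once, then extend the target one token at a time.

    encode_fn(src) -> encoding                      (hypothetical wrapper)
    decode_fn(trg, encoding, mask) -> logits of shape (batch, trg_len, vocab)
    """
    encoding = encode_fn(src)  # computed once, reused at every decoding step
    trg = jnp.full((src.shape[0], 1), bos_id, dtype=jnp.int32)
    for _ in range(max_len - 1):
        step_len = trg.shape[1]
        # Look-ahead mask so each position only attends to earlier positions.
        mask = jnp.tril(jnp.ones((step_len, step_len), dtype=jnp.int32))
        logits = decode_fn(trg, encoding, mask)
        # Keep the single most likely next token for every sequence in the batch.
        next_token = jnp.argmax(logits[:, -1, :], axis=-1).astype(jnp.int32)
        trg = jnp.concatenate([trg, next_token[:, None]], axis=1)
    return trg


if __name__ == "__main__":
    vocab = 6

    def toy_encode(src):
        return src.astype(jnp.float32)

    def toy_decode(trg, encoding, mask):
        # Toy stand-in: always predicts (last token + 1) mod vocab.
        return jax.nn.one_hot((trg + 1) % vocab, vocab)

    src = jnp.zeros((1, 3), dtype=jnp.int32)
    print(greedy_decode(toy_encode, toy_decode, src, bos_id=1, max_len=4))  # [[1 2 3 4]]
```

Greedy decoding keeps only the most likely token at each step; the paper's translation experiments use beam search instead, which tracks several candidate sequences.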
Contributing
This library is not perfect and can be improved in several ways.
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Please make sure to update tests as appropriate.
License
Download files
File details
Details for the file vanilla-transformer-jax-0.0.4.tar.gz.
File metadata
- Download URL: vanilla-transformer-jax-0.0.4.tar.gz
- Upload date:
- Size: 4.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/3.7.3 pkginfo/1.8.2 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.9.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f0e09c5d00f850507dc64dc402517f553f4339b2f2aba3a6cd4f21ca5e72778d |
| MD5 | 9efc5225fe2519c2ad9f9f9a9b21ed05 |
| BLAKE2b-256 | 94da354beea34817dea93dbf4c83b0b46f3d828501568f02a61d3ca80bf57dc9 |
File details
Details for the file vanilla_transformer_jax-0.0.4-py3-none-any.whl.
File metadata
- Download URL: vanilla_transformer_jax-0.0.4-py3-none-any.whl
- Upload date:
- Size: 5.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/3.7.3 pkginfo/1.8.2 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.9.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8762fd684a98a58b69a4be3770b0e4bdd59f8e3dbcfe9161a9c4f870520beb8b |
| MD5 | 1478a026279a7a98c9d7da9367b01bf9 |
| BLAKE2b-256 | ba3b763963abc597446f17128fa0a3348adb80d76c7236eae0025ccf265cba60 |