
JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al.

Vanilla Transformer

JAX/Flax implementation of 'Attention Is All You Need' by Vaswani et al. (https://arxiv.org/abs/1706.03762)

Installation

Install the package with pip:

pip install vanilla-transformer-jax

Usage

To use the full Transformer model (encoder and decoder together):

import jax
import jax.numpy as jnp
from vtransformer import Transformer  # imports the Transformer class

model = Transformer()  # hyperparameters can be overridden; otherwise the defaults from the paper are used

prng = jax.random.PRNGKey(42)
src_key, trg_key, init_key = jax.random.split(prng, 3)  # separate keys so the draws are independent

example_input_src = jax.random.randint(src_key, (3, 4), minval=0, maxval=10000)
example_input_trg = jax.random.randint(trg_key, (3, 5), minval=0, maxval=10000)
mask = jnp.array([1, 1, 1, 0, 0])

init = model.init(init_key, example_input_src, example_input_trg, mask)  # initialize the model parameters

output = model.apply(init, example_input_src, example_input_trg, mask)  # run the forward pass
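
The example above passes a hand-written `mask` of ones and zeros. Assuming that 1 marks a real token and 0 marks padding (this README does not state the convention), a mask of that form could be derived from padded token ids instead of being typed by hand. The following is a minimal sketch under that assumption; `PAD_ID` and `make_padding_mask` are hypothetical names for illustration and are not part of `vtransformer`.

```python
import jax.numpy as jnp

PAD_ID = 0  # assumption: token id 0 marks padding; the package may use a different convention


def make_padding_mask(token_ids, pad_id=PAD_ID):
    """Return 1 where a position holds a real token and 0 where it holds padding."""
    return (token_ids != pad_id).astype(jnp.int32)


# A single target sequence of length 5 whose last two positions are padding.
trg = jnp.array([17, 42, 8, PAD_ID, PAD_ID])
mask = make_padding_mask(trg)
print(mask)  # [1 1 1 0 0]
```

Whether the model expects one mask per batch or one per sequence is not specified here, so check the package source before relying on a particular shape.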

To use the Encoder and Decoder separately:

encoding = model.encoder(init, example_input_src)  # using only the encoder
decoding = model.decoder(init, example_input_trg, encoding, mask)  # using only the decoder

Contributing

This library is not perfect and can be improved in several respects.

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.

Please make sure to update tests as appropriate.

License

MIT
