Transformer-pytorch
A PyTorch implementation of Transformer from "Attention is All You Need" (https://arxiv.org/abs/1706.03762).
This repo focuses on a clean, readable, and modular implementation of the paper.
$ docker build --tag transformer --rm .
$ docker run --name $CONTAINER_NAME -it --gpus all --shm-size 16G --volume $(pwd):/transformer transformer bash
$ python main.py v1 prepare multi30k data/Multi30k/
$ python -m spacy download en
$ python -m spacy download de
$ python main.py v1 build-vocab data/Multi30k/train.json results/vocabs/shared_vocab.tsv --source-language de --target-language en --min-freq 3
Vocabulary 9521 lines written to results/vocabs/shared_vocab.tsv
$ CUDA_VISIBLE_DEVICES=0 python main.py v1 train data/Multi30k/train.json data/Multi30k/val.json results/vocabs/shared_vocab.tsv results/vocabs/shared_vocab.tsv results/runs playground
$ python main.py v1 prepare aihub data/AIHub/
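The build-vocab step above writes a shared source/target vocabulary with a minimum token frequency of 3. As a rough illustration of the idea (not the repo's actual build-vocab code; the function name, special-token names, and file handling here are assumptions), such a vocabulary could be built like this:

from collections import Counter

def build_shared_vocab(tokenized_sentences, min_freq=3,
                       specials=("<pad>", "<unk>", "<sos>", "<eos>")):
    # Count whitespace-separated tokens and keep those seen at least min_freq times.
    counts = Counter(tok for sent in tokenized_sentences for tok in sent.split())
    return list(specials) + [tok for tok, c in counts.most_common() if c >= min_freq]

# Hypothetical usage with German and English sentences pooled into one shared vocabulary:
# vocab = build_shared_vocab(de_sentences + en_sentences, min_freq=3)
# with open("results/vocabs/shared_vocab.tsv", "w", encoding="utf-8") as f:
#     f.write("\n".join(vocab))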
Requirements
- Python 3.6+
- PyTorch 0.4.1+
- NumPy
- NLTK
- tqdm
Usage
Prepare datasets
This repo comes with example data in the data/ directory. To begin, prepare the datasets from the given data as follows:
$ python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed
The example data is taken from OpenNMT-py. It consists of parallel source (src) and target (tgt) files for training and validation. Each data file contains one sentence per line, with tokens separated by spaces. The provided example data files are listed below, followed by a minimal reading sketch.
src-train.txt
tgt-train.txt
src-val.txt
tgt-val.txt
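Since each data file holds one whitespace-tokenized sentence per line, a parallel corpus can be read with something like the following minimal sketch (the function name is illustrative, not part of the repo's API):

def read_parallel_corpus(src_path, tgt_path):
    # Read aligned source/target files: one sentence per line, tokens split on spaces.
    with open(src_path, encoding="utf-8") as src_f, open(tgt_path, encoding="utf-8") as tgt_f:
        return [(s.split(), t.split()) for s, t in zip(src_f, tgt_f)]

# Example with the provided files:
# train_pairs = read_parallel_corpus("data/example/raw/src-train.txt",
#                                    "data/example/raw/tgt-train.txt")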
Train model
To train a model, give the train script the path to the processed data and the files to save the outputs to:
$ python train.py --data_dir=data/example/processed --save_config=checkpoints/example_config.json --save_checkpoint=checkpoints/example_model.pth --save_log=logs/example.log
This saves the model config, checkpoints, and training log to the given files, respectively.
You can tune the model's hyperparameters with command-line arguments. For example, add --epochs=300 to set the number of epochs to 300.
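One hyperparameter worth knowing about is the learning-rate schedule: optimizers.py (see the file description below) implements the Noam optimizer from the paper, which warms the learning rate up linearly and then decays it with the inverse square root of the step. A minimal sketch of that formula, with illustrative default values:

def noam_lr(step, d_model=512, warmup_steps=4000):
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    step = max(step, 1)  # avoid step ** -0.5 at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)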
Translate
To translate a sentence from the source language into the target language:
$ python predict.py --source="There is an imbalance here ." --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth
Candidate 0 : Hier fehlt das Gleichgewicht .
Candidate 1 : Hier fehlt das das Gleichgewicht .
Candidate 2 : Hier fehlt das das das Gleichgewicht .
This gives you translation candidates for the given source sentence. You can adjust the number of candidates with a command-line argument.
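The candidates come from the beam search in beam.py. Below is a generic beam-search sketch over per-step log-probabilities, not the repo's beam.py API; the step_fn interface is an assumption made for illustration (and there is no length normalization here):

def beam_search(step_fn, sos_id, eos_id, beam_size=3, max_len=50):
    # step_fn(prefix) is assumed to return {token_id: log_prob} for the next token.
    beams = [([sos_id], 0.0)]            # (token ids, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for prefix, score in beams:
            if prefix[-1] == eos_id:     # finished hypotheses are set aside
                finished.append((prefix, score))
            else:
                for tok, logp in step_fn(prefix).items():
                    candidates.append((prefix + [tok], score + logp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if not beams:                    # every hypothesis has produced <eos>
            break
    finished.extend(beams)               # hypotheses that never produced <eos> within max_len
    return sorted(finished, key=lambda c: c[1], reverse=True)[:beam_size]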
Evaluate
To calculate the BLEU score of a trained model:
$ python evaluate.py --save_result=logs/example_eval.txt --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth
BLEU score : 0.0007947
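For reference, NLTK (listed in the requirements) provides corpus-level BLEU. The following generic sketch shows how such a score can be computed on tokenized references and hypotheses; it is not necessarily how evaluate.py does it, and the example sentences are made up:

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

# One list of reference translations per hypothesis; all sentences are token lists.
references = [[["hier", "fehlt", "das", "gleichgewicht", "."]]]
hypotheses = [["hier", "fehlt", "das", "das", "gleichgewicht", "."]]

score = corpus_bleu(references, hypotheses,
                    smoothing_function=SmoothingFunction().method1)
print("BLEU score :", score)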
File description
- models.py includes the Transformer encoder, decoder, and multi-head attention.
- embeddings.py contains the positional encoding (see the sketch after this list).
- losses.py contains the label smoothing loss.
- optimizers.py contains the Noam optimizer.
- metrics.py contains the accuracy metric.
- beam.py contains beam search.
- datasets.py has code for loading and processing data.
- trainer.py has code for training a model.
- prepare_datasets.py processes data.
- train.py trains a model.
- predict.py translates a given source sentence with a trained model.
- evaluate.py calculates the BLEU score of a trained model.
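As an illustration of one of these components, the sinusoidal positional encoding from the paper adds sin(pos / 10000^(2i/d_model)) to even dimensions and cos of the same argument to odd dimensions. A standalone sketch of that table, not embeddings.py verbatim:

import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    # Returns a (max_len, d_model) tensor of fixed sin/cos encodings (assumes d_model is even).
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)           # (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                         * (-math.log(10000.0) / d_model))                     # (d_model / 2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe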
Deploy
$ python3 setup.py sdist bdist_wheel
$ python3 -m twine upload --repository testpypi dist/*
$ python3 -m twine upload dist/*