PyTorch implementation of Transformer from "Attention is All You Need".

Transformer-pytorch

A PyTorch implementation of Transformer from "Attention is All You Need" (https://arxiv.org/abs/1706.03762).

This repo focuses on a clean, readable, and modular implementation of the paper.

$ docker build --tag transformer --rm .
$ docker run --name $CONTAINER_NAME -it --gpus all --shm-size 16G --volume $(pwd):/transformer transformer bash
$ python main.py v1 prepare multi30k data/Multi30k/
$ python -m spacy download en
$ python -m spacy download de
$ python main.py v1 build-vocab data/Multi30k/train.json results/vocabs/shared_vocab.tsv --source-language de --target-language en --min-freq 3

Vocabulary 9521 lines written to results/vocabs/shared_vocab.tsv

$ CUDA_VISIBLE_DEVICES=0 python main.py v1 train data/Multi30k/train.json data/Multi30k/val.json results/vocabs/shared_vocab.tsv results/vocabs/shared_vocab.tsv results/runs playground
$ python main.py v1 prepare aihub data/AIHub/

Requirements

Usage

Prepare datasets

This repo comes with example data in the data/ directory. To begin, prepare the datasets from the given data as follows:

$ python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed

The example data is taken from OpenNMT-py. It consists of parallel source (src) and target (tgt) data for training and validation. Each data file contains one sentence per line, with tokens separated by spaces. The provided example data files are:

  • src-train.txt
  • tgt-train.txt
  • src-val.txt
  • tgt-val.txt
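The file format described above (one sentence per line, tokens separated by spaces, source and target files aligned line by line) can be read with a few lines of Python. This is a minimal sketch for illustration only; `read_parallel` is a hypothetical helper, not a function from this repo, and prepare_datasets.py may process the files differently.

```python
from pathlib import Path


def read_parallel(src_path, tgt_path):
    """Read aligned source/target files into a list of (src_tokens, tgt_tokens).

    Hypothetical helper: assumes UTF-8 text, one sentence per line,
    and tokens separated by whitespace.
    """
    src_lines = Path(src_path).read_text(encoding="utf-8").splitlines()
    tgt_lines = Path(tgt_path).read_text(encoding="utf-8").splitlines()

    # Parallel data must have the same number of sentences on both sides.
    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"line count mismatch: {len(src_lines)} source vs {len(tgt_lines)} target"
        )

    # Split each sentence into tokens on whitespace.
    return [(s.split(), t.split()) for s, t in zip(src_lines, tgt_lines)]


# Example call (paths taken from the command above):
# pairs = read_parallel("data/example/raw/src-train.txt",
#                       "data/example/raw/tgt-train.txt")
```

Checking the line counts up front catches misaligned files early, before they silently produce wrong sentence pairs.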

Train model

To train the model, pass the training script the path to the processed data and the paths where the config, checkpoint, and log should be saved:

$ python train.py --data_dir=data/example/processed --save_config=checkpoints/example_config.json --save_checkpoint=checkpoints/example_model.pth --save_log=logs/example.log 

This saves the model config and checkpoints to the given files. You can adjust the model's hyperparameters with command-line arguments. For example, add --epochs=300 to set the number of epochs to 300.

Translate

To translate a sentence from the source language to the target language:

$ python predict.py --source="There is an imbalance here ." --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

Candidate 0 : Hier fehlt das Gleichgewicht .
Candidate 1 : Hier fehlt das das Gleichgewicht .
Candidate 2 : Hier fehlt das das das Gleichgewicht .

This prints translation candidates for the given source sentence. You can adjust the number of candidates with a command-line argument.
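Ranked candidates like these are what a beam search typically produces. The sketch below shows the general idea on a toy next-token table. It is not the code in beam.py: `beam_search`, `toy_model`, and the probabilities in `TABLE` are all made up for illustration. It also omits length normalization, and finished hypotheses leave the beam instead of competing with running ones.

```python
import math


def beam_search(next_log_probs, start, end, beam_size=3, max_len=10):
    """Return up to beam_size finished (score, tokens) pairs, best first.

    next_log_probs(tokens) must return a dict mapping each possible
    next token to its log-probability.
    """
    beams = [(0.0, [start])]  # running hypotheses: (log-prob, tokens)
    finished = []
    for _ in range(max_len):
        # Expand every running hypothesis by every possible next token.
        candidates = []
        for score, seq in beams:
            for token, logp in next_log_probs(seq).items():
                candidates.append((score + logp, seq + [token]))
        # Keep only the beam_size highest-scoring expansions.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_size]:
            if seq[-1] == end:
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        if not beams:
            break
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished[:beam_size]


# Toy "model": the next-token distribution depends only on the last token.
# These probabilities are invented for the example.
TABLE = {
    "<s>": {"Hier": 0.9, "Da": 0.1},
    "Hier": {"fehlt": 1.0},
    "Da": {"fehlt": 1.0},
    "fehlt": {"das": 1.0},
    "das": {"Gleichgewicht": 0.7, "das": 0.3},
    "Gleichgewicht": {".": 1.0},
    ".": {"</s>": 1.0},
}


def toy_model(seq):
    return {tok: math.log(p) for tok, p in TABLE[seq[-1]].items()}


if __name__ == "__main__":
    for i, (score, seq) in enumerate(beam_search(toy_model, "<s>", "</s>", beam_size=2)):
        print(f"Candidate {i} : {' '.join(seq[1:-1])}")
```

A larger `beam_size` keeps more partial hypotheses alive at each step, which yields more candidates at the cost of more computation.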

Evaluate

To calculate the BLEU score of a trained model:

$ python evaluate.py --save_result=logs/example_eval.txt --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

BLEU score : 0.0007947
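For readers unfamiliar with the metric, here is a minimal sketch of corpus-level BLEU: clipped n-gram precisions for n = 1 to 4, combined by a geometric mean and multiplied by a brevity penalty. It assumes one reference per hypothesis and applies no smoothing. `corpus_bleu` here is an illustrative function written for this README, and evaluate.py may compute the score differently.

```python
import math
from collections import Counter


def ngram_counts(tokens, n):
    """Count all n-grams of length n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses, references, max_n=4):
    """Unsmoothed corpus-level BLEU, one reference per hypothesis."""
    matches = [0] * max_n  # clipped n-gram matches per order
    totals = [0] * max_n   # hypothesis n-gram counts per order
    hyp_len = ref_len = 0

    for hyp, ref in zip(hypotheses, references):
        hyp_len += len(hyp)
        ref_len += len(ref)
        for n in range(1, max_n + 1):
            hyp_counts = ngram_counts(hyp, n)
            ref_counts = ngram_counts(ref, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += sum(hyp_counts.values())

    # Without smoothing, any order with zero matches gives a score of zero.
    if hyp_len == 0 or min(matches) == 0:
        return 0.0

    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Penalize hypotheses that are shorter than the references overall.
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * math.exp(log_precision)


if __name__ == "__main__":
    hyp = ["Hier fehlt das Gleichgewicht .".split()]
    print(corpus_bleu(hyp, hyp))  # identical hypothesis and reference
```

Because the sketch is unsmoothed, a single missing 4-gram match across the whole corpus drives the score to zero, which is one reason very low scores appear early in training.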

File description

  • models.py includes Transformer's encoder, decoder, and multi-head attention.
  • embeddings.py contains positional encoding.
  • losses.py contains label smoothing loss.
  • optimizers.py contains Noam optimizer.
  • metrics.py contains accuracy metric.
  • beam.py contains beam search.
  • datasets.py has code for loading and processing data.
  • trainer.py has code for training model.
  • prepare_datasets.py processes data.
  • train.py trains model.
  • predict.py translates given source sentence with a trained model.
  • evaluate.py calculates BLEU score of a trained model.
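Two of the components listed above have compact closed forms in the paper: the sinusoidal positional encoding and the warmup learning-rate schedule commonly called the Noam schedule. The sketch below writes both out in plain Python to make the formulas concrete. The function names and defaults (d_model=512, warmup_steps=4000, taken from the paper's base model) are assumptions for illustration; the code in embeddings.py and optimizers.py operates on PyTorch tensors and may differ in detail.

```python
import math


def positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding as a max_len x d_model list of lists.

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    """
    table = []
    for pos in range(max_len):
        row = []
        for dim in range(d_model):
            # Each sin/cos pair (dims 2i and 2i+1) shares one frequency.
            angle = pos / (10000 ** (2 * (dim // 2) / d_model))
            row.append(math.sin(angle) if dim % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def noam_learning_rate(step, d_model=512, warmup_steps=4000):
    """Learning rate from the paper's schedule:

    lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    """
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


if __name__ == "__main__":
    print(positional_encoding(2, 4)[0])  # position 0: alternating sin(0), cos(0)
    for step in (1, 1000, 4000, 16000):
        print(step, noam_learning_rate(step))
```

The learning rate rises linearly for the first warmup_steps steps, peaks there, and then decays in proportion to the inverse square root of the step number.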

Reference

Author

@dreamgonfly

Deploy

$ python3 setup.py sdist bdist_wheel
$ python3 -m twine upload --repository testpypi dist/*
$ python3 -m twine upload --repository dist/*

