
PyTorch implementation of Transformer from "Attention is All You Need".


Transformer-pytorch

A PyTorch implementation of Transformer from "Attention is All You Need" (https://arxiv.org/abs/1706.03762).

This repo focuses on a clean, readable, and modular implementation of the paper.

Build the Docker image and start a container with GPU access:

$ docker build --tag transformer --rm .
$ docker run --name $CONTAINER_NAME -it --gpus all --shm-size 16G --volume $(pwd):/transformer transformer bash

Prepare the Multi30k dataset, download the spaCy English and German models, and build a shared German-English vocabulary:

$ python main.py v1 prepare multi30k data/Multi30k/
$ python -m spacy download en
$ python -m spacy download de
$ python main.py v1 build-vocab data/Multi30k/train.json results/vocabs/shared_vocab.tsv --source-language de --target-language en --min-freq 3

Vocabulary 9521 lines written to results/vocabs/shared_vocab.tsv
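The build-vocab step above can be pictured as counting tokens from both languages into a single vocabulary and dropping rare ones. Here is a minimal sketch, not the repo's implementation. It assumes that each example is a dict keyed by language code holding whitespace-tokenized text, that --min-freq keeps tokens seen at least that many times, and that a few special tokens (hypothetical names here) are reserved at the front. The actual train.json layout, spaCy tokenization, and TSV format may differ.

```python
from collections import Counter

# Hypothetical reserved symbols; the repo's actual special tokens may differ.
SPECIAL_TOKENS = ["<pad>", "<unk>", "<bos>", "<eos>"]


def build_shared_vocab(examples, source_language, target_language, min_freq):
    """Build one vocabulary covering both languages.

    `examples` is assumed to be a list of dicts mapping a language code to a
    whitespace-tokenized sentence (an assumed layout, not train.json's).
    """
    counter = Counter()
    for example in examples:
        counter.update(example[source_language].split())
        counter.update(example[target_language].split())
    # Keep tokens seen at least `min_freq` times: most frequent first,
    # ties broken alphabetically so the order is deterministic.
    kept = sorted(
        (token for token, count in counter.items() if count >= min_freq),
        key=lambda token: (-counter[token], token),
    )
    return SPECIAL_TOKENS + kept


examples = [
    {"de": "ein hund rennt", "en": "a dog runs"},
    {"de": "ein mann rennt", "en": "a man runs"},
]
vocab = build_shared_vocab(examples, "de", "en", min_freq=2)
```

The alphabetical tie-break keeps the vocabulary order, and therefore the token indices, stable across runs.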

Train on GPU 0, using the shared vocabulary file:

$ CUDA_VISIBLE_DEVICES=0 python main.py v1 train data/Multi30k/train.json data/Multi30k/val.json results/vocabs/shared_vocab.tsv results/vocabs/shared_vocab.tsv results/runs playground

The AIHub dataset is prepared with the same subcommand:

$ python main.py v1 prepare aihub data/AIHub/

Requirements

Usage

Prepare datasets

This repo comes with example data in the data/ directory. To begin, prepare the datasets from the given data as follows:

$ python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed

The example data is taken from OpenNMT-py. It consists of parallel source (src) and target (tgt) files for training and validation. Each file contains one sentence per line, with tokens separated by spaces. The provided example data files are:

  • src-train.txt
  • tgt-train.txt
  • src-val.txt
  • tgt-val.txt
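Here is a minimal sketch of reading one source/target pair of these files into token lists. It is based only on the format described above: one sentence per line, space-separated tokens, and line i of the source aligned with line i of the target. `read_parallel` is a hypothetical helper, not part of prepare_datasets.py.

```python
import tempfile
from pathlib import Path


def read_parallel(source_path, target_path):
    """Pair line i of the source file with line i of the target file.

    Each line holds one sentence whose tokens are separated by spaces.
    """
    source_lines = Path(source_path).read_text(encoding="utf-8").splitlines()
    target_lines = Path(target_path).read_text(encoding="utf-8").splitlines()
    if len(source_lines) != len(target_lines):
        raise ValueError(
            f"{source_path} has {len(source_lines)} lines but "
            f"{target_path} has {len(target_lines)}"
        )
    # str.split() with no argument splits on runs of whitespace.
    return [(src.split(), tgt.split()) for src, tgt in zip(source_lines, target_lines)]


# Usage with a throwaway one-line corpus.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "src-train.txt"
    tgt = Path(tmp) / "tgt-train.txt"
    src.write_text("There is an imbalance here .\n", encoding="utf-8")
    tgt.write_text("Hier fehlt das Gleichgewicht .\n", encoding="utf-8")
    pairs = read_parallel(src, tgt)
```

Checking that both files have the same number of lines catches misaligned corpora early, before they silently pair the wrong sentences.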

Train model

To train a model, give the training script the path to the processed data and the files to save the config, checkpoint, and log to:

$ python train.py --data_dir=data/example/processed --save_config=checkpoints/example_config.json --save_checkpoint=checkpoints/example_model.pth --save_log=logs/example.log 

This saves the model config, checkpoint, and training log to the given files, respectively. Hyperparameters can be changed with command-line arguments. For example, add --epochs=300 to train for 300 epochs.

Translate

To translate a sentence from the source language into the target language:

$ python predict.py --source="There is an imbalance here ." --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

Candidate 0 : Hier fehlt das Gleichgewicht .
Candidate 1 : Hier fehlt das das Gleichgewicht .
Candidate 2 : Hier fehlt das das das Gleichgewicht .

This prints translation candidates for the given source sentence. The number of candidates can be adjusted with a command-line argument.
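Multiple candidates like these are what beam search (beam.py) produces. Here is a minimal, model-agnostic sketch. It assumes a hypothetical `step_fn` that maps a token prefix to next-token log-probabilities. Unlike many implementations (and possibly beam.py), it applies no length normalization. A hypothesis that emits the end token also leaves the beam, so later steps expand fewer beams.

```python
import math


def beam_search(step_fn, bos, eos, beam_size, max_len):
    """Return up to `beam_size` (tokens, score) pairs, best first.

    `step_fn(prefix)` is a hypothetical callable returning a dict that maps
    each candidate next token to its log-probability given `prefix`.
    """
    beams = [([bos], 0.0)]  # (token sequence, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            for token, log_prob in step_fn(tokens).items():
                candidates.append((tokens + [token], score + log_prob))
        # Keep only the highest-scoring expansions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == eos:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    # Hypotheses still open after max_len steps are returned unfinished.
    finished.extend(beams)
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_size]


def toy_step(prefix):
    # Hypothetical model: first prefer "a" over "b", then prefer ending.
    if len(prefix) == 1:
        return {"a": math.log(0.6), "b": math.log(0.4)}
    return {"</s>": math.log(0.7), "a": math.log(0.3)}


results = beam_search(toy_step, "<s>", "</s>", beam_size=2, max_len=3)
```

Without length normalization, summed log-probabilities favor shorter outputs. Dividing each score by its length is a common adjustment.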

Evaluate

To calculate the BLEU score of a trained model:

$ python evaluate.py --save_result=logs/example_eval.txt --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

BLEU score : 0.0007947
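How evaluate.py tokenizes, smooths, and scales BLEU is not described here. For reference, here is a minimal corpus-level BLEU sketch: clipped n-gram precisions for n = 1..4, combined by a geometric mean and multiplied by a brevity penalty. It uses one reference per sentence and no smoothing (`corpus_bleu` is a hypothetical name). It returns a value in [0, 1], whereas tools such as sacreBLEU report BLEU on a 0 to 100 scale.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams of a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses, references, max_n=4):
    """Corpus BLEU with one reference per hypothesis and no smoothing.

    Both arguments are lists of token lists. Returns a score in [0, 1].
    """
    matches = [0] * max_n  # clipped n-gram matches per order
    totals = [0] * max_n   # hypothesis n-gram counts per order
    hyp_len = ref_len = 0
    for hyp, ref in zip(hypotheses, references):
        hyp_len += len(hyp)
        ref_len += len(ref)
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp, n)
            # Counter & Counter keeps the minimum count: the "clipping" step.
            matches[n - 1] += sum((hyp_counts & ngrams(ref, n)).values())
            totals[n - 1] += sum(hyp_counts.values())
    if min(matches) == 0:
        return 0.0  # some precision is zero, so unsmoothed BLEU is zero
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Brevity penalty: penalize output shorter than the references.
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity * math.exp(log_precision)


hyp = "the cat sat on the mat".split()
ref = "the cat sat on a mat".split()
score = corpus_bleu([hyp], [ref])
```

Without smoothing, any corpus with no matching 4-gram scores exactly zero. This is one reason small or early-training evaluations can report very low BLEU.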

File description

  • models.py includes the Transformer's encoder, decoder, and multi-head attention.
  • embeddings.py contains the positional encoding.
  • losses.py contains the label smoothing loss.
  • optimizers.py contains the Noam optimizer.
  • metrics.py contains the accuracy metric.
  • beam.py contains beam search.
  • datasets.py has code for loading and processing data.
  • trainer.py has code for training the model.
  • prepare_datasets.py processes the raw data.
  • train.py trains the model.
  • predict.py translates a given source sentence with a trained model.
  • evaluate.py calculates the BLEU score of a trained model.
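Two of the formulas behind embeddings.py and optimizers.py are given in the paper: sinusoidal positional encodings, and the warmup-then-decay learning rate that this repo calls the Noam optimizer. The learning rate is d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5). The encodings are PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). Below is a pure-Python sketch of both, not the repo's PyTorch code. The `factor` argument is a hypothetical extra scale that leaves the paper's formula unchanged at 1.0.

```python
import math


def noam_lr(step, d_model=512, warmup_steps=4000, factor=1.0):
    """Learning rate from the schedule in "Attention is All You Need".

    Rises linearly for `warmup_steps` steps, then decays as 1/sqrt(step).
    `factor` is a hypothetical extra scale (1.0 gives the paper's formula).
    """
    step = max(step, 1)  # the formula is undefined at step 0
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def positional_encoding(max_len, d_model):
    """Sinusoidal position encodings as a max_len x d_model nested list."""
    table = []
    for pos in range(max_len):
        row = []
        for i in range(d_model):
            # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k/d_model).
            angle = pos / 10000 ** (2 * (i // 2) / d_model)
            # Even dimensions use sine, odd dimensions use cosine.
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


pe = positional_encoding(max_len=4, d_model=6)
```

The learning rate peaks at step == warmup_steps, where both terms inside the min are equal.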

Reference

Author

@dreamgonfly

Deploy

python3 setup.py sdist bdist_wheel
python3 -m twine upload --repository testpypi dist/*
python3 -m twine upload dist/*

Download files


Source Distribution

transformer-pytorch-0.0.1.tar.gz (14.7 kB)

Built Distribution

transformer_pytorch-0.0.1-py3-none-any.whl (20.5 kB)

File details

Details for the file transformer-pytorch-0.0.1.tar.gz.

File metadata

  • Download URL: transformer-pytorch-0.0.1.tar.gz
  • Size: 14.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.50.2 CPython/3.7.4

File hashes

Hashes for transformer-pytorch-0.0.1.tar.gz
Algorithm Hash digest
SHA256 0e1defa5623fe184a9265ca71d5611650087d3f54eb513cb62ec44d948ae7b14
MD5 52dbfedf15bea20cb83ec7cf33872d57
BLAKE2b-256 5da8f887edcb3fbbdaa9482ba5497d7500545960d932600270baff26e670b4f4


File details

Details for the file transformer_pytorch-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: transformer_pytorch-0.0.1-py3-none-any.whl
  • Size: 20.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.50.2 CPython/3.7.4

File hashes

Hashes for transformer_pytorch-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b34986f9fad422a71c8953c0cb0c7b08b811aafbb6755faa89d6066d05882f98
MD5 aec87c49c9bcafbd5e8b366556f9d47b
BLAKE2b-256 3aee48b50a868973b123b869e2db3092a0bc36405eb9e8cf297154a0f729f303

