GPT-2 implemented in PyTorch
PyTorch GPT-2
Install
pip install torch-gpt-2
Demo
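The demo takes MODEL_FOLDER as its only argument and builds four paths from it. If the folder is an unpacked OpenAI GPT-2 release (for example the 117M model), `model.ckpt` is a TensorFlow checkpoint *prefix*, not a single file, so a naive existence check on that path would fail. The snippet below is a minimal sketch of a pre-flight check under that assumption. `missing_model_files` and `REQUIRED_FILES` are hypothetical names and are not part of `torch_gpt_2`.

```python
import os

# Assumption: the layout of an OpenAI GPT-2 release folder.
# 'model.ckpt' is a TensorFlow checkpoint prefix, so we look for its
# '.index' file rather than a file literally named 'model.ckpt'.
REQUIRED_FILES = ['hparams.json', 'encoder.json', 'vocab.bpe', 'model.ckpt.index']


def missing_model_files(model_folder):
    """Return the expected files that are absent from model_folder (hypothetical helper)."""
    return [name for name in REQUIRED_FILES
            if not os.path.isfile(os.path.join(model_folder, name))]
```

For example, `missing_model_files(sys.argv[1])` returns an empty list when the folder looks complete, so the script can exit early with a clear message instead of failing inside the checkpoint loader. The full demo script: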
import os
import sys
from torch_gpt_2 import load_trained_model_from_checkpoint, get_bpe_from_files, generate
if len(sys.argv) != 2:
    print('python3 demo.py MODEL_FOLDER')
    sys.exit(-1)
model_folder = sys.argv[1]
config_path = os.path.join(model_folder, 'hparams.json')
checkpoint_path = os.path.join(model_folder, 'model.ckpt')
encoder_path = os.path.join(model_folder, 'encoder.json')
vocab_path = os.path.join(model_folder, 'vocab.bpe')
print('Load net from checkpoint...')
net = load_trained_model_from_checkpoint(config_path, checkpoint_path)
print('Load BPE from files...')
bpe = get_bpe_from_files(encoder_path, vocab_path)
print('Generate text...')
output = generate(net, bpe, ['From the day forth, my arm'], length=20, top_k=1)
# If you are using the 117M model and top_k equals 1, the output should be:
# "From the day forth, my arm was broken, and I was in a state of pain. I was in a state of pain,"
print(output[0])
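The expected output above depends on `top_k=1` making generation deterministic. As a generic illustration, and assuming (without having checked the library's source) that `generate` applies standard top-k filtering at each step, top-k sampling keeps only the k highest-scoring next tokens and samples among them in proportion to their softmax probabilities. With k = 1 this collapses to always picking the single most likely token, i.e. greedy decoding. `top_k_sample` is a hypothetical helper that works on a plain list of logits; it is not `torch_gpt_2`'s implementation.

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Pick a token id from `logits`, considering only the k highest-scoring ids.

    Generic illustration of top-k sampling, not torch_gpt_2's own code.
    """
    # Indices of the k largest logits (ties broken by the lower index).
    top = sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:k]
    # Unnormalized softmax over the surviving logits; subtracting the max keeps exp() stable.
    max_logit = logits[top[0]]
    weights = [math.exp(logits[i] - max_logit) for i in top]
    # choices() normalizes the weights itself and returns a list of k=1 draws.
    return rng.choices(top, weights=weights, k=1)[0]
```

With `logits = [0.1, 2.5, -1.0, 2.4]`, `top_k_sample(logits, k=1)` always returns `1`, while `k=2` returns either `1` or `3`. Larger values of k trade reproducibility for more varied text.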
Download files
Source Distribution
torch-gpt-2-0.3.0.tar.gz (5.9 kB)
File details
Details for the file torch-gpt-2-0.3.0.tar.gz.
File metadata
- Download URL: torch-gpt-2-0.3.0.tar.gz
- Upload date:
- Size: 5.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.7.1 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.6.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 36c99f38fb8a910a3a38f483f776d4f23c19ddfe46387145afda785cdaff2fd0
MD5 | 47558bbd8c54c701218030e5d18d3a8f
BLAKE2b-256 | 7d5cd784bde9fc5c2ed4a2d2b0d731c2e361e4acfc3d6d141629d783898a16bb
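The SHA256 digest for torch-gpt-2-0.3.0.tar.gz lets you check a downloaded copy of the source distribution. The snippet below is a minimal sketch using Python's standard `hashlib` module. `sha256_of_file` and `verify_download` are hypothetical helper names. The file is read in chunks, so large archives do not need to fit in memory.

```python
import hashlib

# SHA256 digest published for torch-gpt-2-0.3.0.tar.gz.
EXPECTED_SHA256 = '36c99f38fb8a910a3a38f483f776d4f23c19ddfe46387145afda785cdaff2fd0'


def sha256_of_file(path, chunk_size=1 << 16):
    """Compute the hex SHA256 digest of a file, reading it chunk by chunk."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        # iter() with a sentinel stops when read() returns b'' at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=EXPECTED_SHA256):
    """Return True if the file at `path` matches the expected SHA256 digest."""
    return sha256_of_file(path) == expected.lower()
```

For example, `verify_download('torch-gpt-2-0.3.0.tar.gz')` should return `True` for an intact copy of the archive. pip can also enforce this check itself: pin `torch-gpt-2==0.3.0 --hash=sha256:36c99f38fb8a910a3a38f483f776d4f23c19ddfe46387145afda785cdaff2fd0` in a requirements file and install it with `--require-hashes`.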