GPT-2 implemented in PyTorch

PyTorch GPT-2

Install

pip install torch-gpt-2

Demo

import os
import sys
from torch_gpt_2 import load_trained_model_from_checkpoint, get_bpe_from_files, generate


# Expect exactly one argument: the folder holding the downloaded model files.
if len(sys.argv) != 2:
    print('Usage: python3 demo.py MODEL_FOLDER')
    sys.exit(1)


model_folder = sys.argv[1]
config_path = os.path.join(model_folder, 'hparams.json')
checkpoint_path = os.path.join(model_folder, 'model.ckpt')
encoder_path = os.path.join(model_folder, 'encoder.json')
vocab_path = os.path.join(model_folder, 'vocab.bpe')


print('Load net from checkpoint...')
net = load_trained_model_from_checkpoint(config_path, checkpoint_path)
print('Load BPE from files...')
bpe = get_bpe_from_files(encoder_path, vocab_path)
print('Generate text...')
output = generate(net, bpe, ['From the day forth, my arm'], length=20, top_k=1)

# With the 117M model and top_k=1, the result should be:
# "From the day forth, my arm was broken, and I was in a state of pain. I was in a state of pain,"
print(output[0])
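
The demo passes `top_k=1`, which is why it can quote a fixed expected output: only the single highest-scoring token survives at each step, so generation is deterministic. The sketch below illustrates the general top-k idea using only the standard library. It is not taken from `torch_gpt_2`; the function name `top_k_sample` and the example scores are hypothetical, and the package's own sampling code may differ.

```python
import math
import random


def top_k_sample(logits, top_k, rng=random):
    """Sample an index from `logits`, restricted to the `top_k` highest scores.

    Hypothetical helper for illustration; not part of torch_gpt_2.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")
    # Indices of the top_k largest logits (ties go to the lower index).
    candidates = sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:top_k]
    # Softmax weights over the survivors, shifted by the max for stability.
    peak = logits[candidates[0]]
    weights = [math.exp(logits[i] - peak) for i in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


logits = [0.1, 2.5, -1.0, 2.4]  # made-up scores for a 4-token vocabulary
print(top_k_sample(logits, top_k=1))  # always prints 1 (the argmax)
```

With a larger `top_k`, the same call can return any of the `top_k` best indices, so repeated runs of a generator built this way would no longer produce identical text.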


Source Distribution

torch-gpt-2-0.3.0.tar.gz (5.9 kB view details)

Uploaded Source

File details

Details for the file torch-gpt-2-0.3.0.tar.gz.

File metadata

  • Download URL: torch-gpt-2-0.3.0.tar.gz
  • Upload date:
  • Size: 5.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.7.1 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.6.4

File hashes

Hashes for torch-gpt-2-0.3.0.tar.gz

Algorithm     Hash digest
SHA256        36c99f38fb8a910a3a38f483f776d4f23c19ddfe46387145afda785cdaff2fd0
MD5           47558bbd8c54c701218030e5d18d3a8f
BLAKE2b-256   7d5cd784bde9fc5c2ed4a2d2b0d731c2e361e4acfc3d6d141629d783898a16bb
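
One way to use these digests is to hash the downloaded archive locally and compare the result with the published SHA256 value. The following is a minimal sketch using Python's standard `hashlib`; the helper names `sha256_of_file` and `matches_published_hash` are hypothetical, and the local file path in the usage comment is an assumption about where the archive was saved.

```python
import hashlib


def sha256_of_file(path, chunk_size=65536):
    """Return the hex SHA256 digest of the file at `path`, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # iter() with a sentinel stops once read() returns empty bytes.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """Compare a file's SHA256 digest with a published hex digest."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Usage (assumes the archive sits in the current directory):
# expected = "36c99f38fb8a910a3a38f483f776d4f23c19ddfe46387145afda785cdaff2fd0"
# print(matches_published_hash("torch-gpt-2-0.3.0.tar.gz", expected))
```

Reading in chunks keeps memory use flat regardless of file size, which matters little for a 5.9 kB archive but makes the helper reusable for larger downloads such as model checkpoints.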
