
Python package to easily retrain OpenAI's GPT-2 text-generating model on new texts.

Project description

A simple Python package that wraps existing model fine-tuning and generation scripts for OpenAI's GPT-2 text generation model (specifically the "small" 124M-parameter version). The package also makes text generation easier: it can generate to a file for easy curation, and it accepts a prefix to force the text to start with a given phrase.
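The file-based curation workflow can be sketched in plain Python. This is a minimal sketch, not part of gpt-2-simple: `split_samples` and `keep_with_prefix` are hypothetical helpers, and the code assumes samples in a generated file are separated by a line of 20 `=` characters (our understanding of the package's default sample delimiter; adjust `SAMPLE_DELIM` if your files differ).

```python
from typing import List

# Assumption: generated samples are separated by a line of 20 '=' characters.
SAMPLE_DELIM = "=" * 20 + "\n"


def split_samples(text: str, delim: str = SAMPLE_DELIM) -> List[str]:
    """Split the raw contents of a generated file into samples, dropping empty chunks."""
    return [chunk.strip() for chunk in text.split(delim) if chunk.strip()]


def keep_with_prefix(samples: List[str], prefix: str) -> List[str]:
    """Keep only the samples that begin with the given prefix."""
    return [s for s in samples if s.startswith(prefix)]
```

Reading a generated file with `open(...).read()` and passing the result through both helpers gives a shortlist to review by hand.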

Usage

An example of downloading the model to the local system, finetuning it on a dataset, and generating some text:

Warning: the pretrained model, and thus any finetuned model, is 500 MB!

import gpt_2_simple as gpt2

gpt2.download_gpt2()   # model is saved into the current directory under models/124M/

sess = gpt2.start_tf_sess()
gpt2.finetune(sess, 'shakespeare.txt', steps=1000)   # steps is max number of training steps

gpt2.generate(sess)
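Given the 500 MB warning above, it can help to check free disk space before calling `gpt2.download_gpt2()`. A minimal standard-library sketch: `has_room_for_model` is a hypothetical helper, and the 2x margin (room for the pretrained model plus one finetuned copy of similar size) is an assumption, not a package requirement.

```python
import shutil

# Approximate size of the pretrained model on disk (about 500 MB, per the warning).
MODEL_BYTES = 500 * 1024 ** 2


def has_room_for_model(path: str = ".", required_bytes: int = MODEL_BYTES,
                       margin: float = 2.0) -> bool:
    """Return True if `path` has at least required_bytes * margin bytes free.

    The default 2x margin is an assumed allowance for one finetuned
    checkpoint alongside the pretrained model.
    """
    free = shutil.disk_usage(path).free
    return free >= required_bytes * margin
```

For example, the download can be guarded with `if has_room_for_model(): gpt2.download_gpt2()`.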

By default, the generated model checkpoints are saved in checkpoint/run1. If you want to load a model from that folder and generate text from it:

import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)

gpt2.generate(sess)

As with textgenrnn, you can generate and save text for later use (e.g. an API or a bot) by using the return_as_list parameter.

single_text = gpt2.generate(sess, return_as_list=True)[0]
print(single_text)
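Because `return_as_list=True` yields a list of strings, a batch of generated texts can be stored for an API or a bot with the standard library alone. A minimal sketch: `save_samples`, `load_samples`, and the JSON file format are assumptions for illustration, not features of gpt-2-simple.

```python
import json
from pathlib import Path
from typing import List


def save_samples(samples: List[str], path: str) -> None:
    """Write generated samples to a UTF-8 JSON file as a list of strings."""
    Path(path).write_text(json.dumps(samples, ensure_ascii=False, indent=2),
                          encoding="utf-8")


def load_samples(path: str) -> List[str]:
    """Read the stored samples back, e.g. inside an API handler or a bot."""
    return json.loads(Path(path).read_text(encoding="utf-8"))
```

With a live session, `save_samples(gpt2.generate(sess, return_as_list=True), 'samples.json')` would store one batch for later use.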

You can pass a run_name parameter to finetune and load_gpt2 if you want to store/load multiple models in a checkpoint folder.
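To see which `run_name` values are available, you can list the subfolders of the checkpoint folder. A minimal sketch, assuming each run is stored in its own subfolder of `checkpoint/` (as the default `checkpoint/run1` suggests); `list_runs` is a hypothetical helper, not part of the package.

```python
from pathlib import Path
from typing import List


def list_runs(checkpoint_dir: str = "checkpoint") -> List[str]:
    """Return the sorted names of run subfolders (e.g. 'run1') under checkpoint_dir.

    Each name is a candidate value for the run_name parameter.
    Returns an empty list if the folder does not exist yet.
    """
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return []
    return sorted(p.name for p in root.iterdir() if p.is_dir())
```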

NB: Restart the Python session first if you want to finetune on another dataset or load another model.


Download files


Source Distribution

gpt_2_simple-0.8.1.tar.gz (26.5 kB)

File details

Details for the file gpt_2_simple-0.8.1.tar.gz.

File metadata

  • Download URL: gpt_2_simple-0.8.1.tar.gz
  • Upload date:
  • Size: 26.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.5

File hashes

Hashes for gpt_2_simple-0.8.1.tar.gz
  • SHA256: 0d620a17b4c4592190b637548afa0aaf26231c71a107165ee91e90929e398c11
  • MD5: d6065545a8dae4e8c6c263e1e60b2450
  • BLAKE2b-256: ecd5d1e9ab56bd82bc206fbd26a284115bee8101c2057a03d6ce8bcd069a1525
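The published digests can be checked after downloading the archive. A minimal standard-library sketch: `sha256_of` and `verify_download` are hypothetical helper names, and `EXPECTED_SHA256` is the SHA256 digest listed for gpt_2_simple-0.8.1.tar.gz.

```python
import hashlib

# SHA256 digest published for gpt_2_simple-0.8.1.tar.gz.
EXPECTED_SHA256 = "0d620a17b4c4592190b637548afa0aaf26231c71a107165ee91e90929e398c11"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file in chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

pip's hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file) performs the same comparison at install time.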

