Python package to easily retrain OpenAI's GPT-2 text-generating model on new texts.
A simple Python package that wraps existing model fine-tuning and generation scripts for OpenAI's GPT-2 text generation model (specifically the "small" 124M-parameter version). Additionally, this package makes text generation easier: it can generate to a file for easy curation, and it accepts a prefix to force the text to start with a given phrase.
Usage
An example of downloading the model to the local system, finetuning it on a dataset, and generating some text.
Warning: the pretrained model, and thus any finetuned model, is 500 MB!
import gpt_2_simple as gpt2
gpt2.download_gpt2()  # model is saved into the current directory under models/124M/
sess = gpt2.start_tf_sess()
gpt2.finetune(sess, 'shakespeare.txt', steps=1000) # steps is max number of training steps
gpt2.generate(sess)
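Given the size warning above, it can be worth checking free disk space before calling download_gpt2() and finetune(). The sketch below uses only the standard library; the helper name has_room_for_model and its default budget (two copies of roughly 500 MB, one for models/124M and one for a finetuned checkpoint) are assumptions, not part of gpt_2_simple.

```python
import shutil

# Assumed budget: one ~500 MB pretrained model plus one ~500 MB finetuned
# checkpoint, per the size warning. Adjust to your own setup.
DEFAULT_REQUIRED_BYTES = 2 * 500 * 1000 * 1000


def has_room_for_model(path: str = ".", required_bytes: int = DEFAULT_REQUIRED_BYTES) -> bool:
    """Hypothetical helper: True if the filesystem holding `path` has enough free bytes."""
    # shutil.disk_usage returns a named tuple (total, used, free), all in bytes.
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_bytes


if __name__ == "__main__":
    if has_room_for_model("."):
        print("Enough free space for the model and a finetuned checkpoint.")
    else:
        print("Low on disk space; free some up before downloading or finetuning.")
```

Pass a smaller required_bytes if you only need the pretrained model, or a larger one if you plan to keep several finetuned runs.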
The generated model checkpoints are saved by default in checkpoint/run1. If you want to load a model from that folder and generate text from it:
import gpt_2_simple as gpt2
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)
gpt2.generate(sess)
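If nothing has been finetuned yet, load_gpt2(sess) has nothing to load, so a quick existence check can give a clearer message first. This is a hypothetical, standard-library-only helper; it assumes the default checkpoint/run1 layout described above and that TensorFlow's saver has written its usual index file, named checkpoint, inside the run folder.

```python
import os


def checkpoint_exists(run_name: str = "run1", checkpoint_dir: str = "checkpoint") -> bool:
    """Hypothetical helper: True if checkpoint_dir/run_name holds a saved TensorFlow run.

    Assumes TensorFlow's saver wrote its index file, named "checkpoint",
    inside the run folder (e.g. checkpoint/run1/checkpoint).
    """
    index_file = os.path.join(checkpoint_dir, run_name, "checkpoint")
    return os.path.isfile(index_file)


if __name__ == "__main__":
    if checkpoint_exists():
        print("checkpoint/run1 is ready to load.")
    else:
        print("No checkpoint found in checkpoint/run1; run finetune() first.")
```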
As with textgenrnn, you can generate and save text for later use (e.g. an API or a bot) by using the return_as_list parameter.
single_text = gpt2.generate(sess, return_as_list=True)[0]
print(single_text)
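Since return_as_list=True gives back a plain Python list of strings, the output can be stored with the standard library and served later. In the sketch below, the list is a stand-in for gpt2.generate(sess, return_as_list=True), which needs a loaded model; save_samples and load_samples are hypothetical helper names, and JSON is just one reasonable storage format.

```python
import json
from pathlib import Path
from typing import List

# Stand-in for: samples = gpt2.generate(sess, return_as_list=True)
samples: List[str] = [
    "First generated sample.",
    "Second generated sample, with a non-ASCII character: é",
]


def save_samples(texts: List[str], path: str) -> None:
    """Hypothetical helper: write generated texts to a UTF-8 JSON file."""
    Path(path).write_text(json.dumps(texts, ensure_ascii=False, indent=2), encoding="utf-8")


def load_samples(path: str) -> List[str]:
    """Hypothetical helper: read the stored texts back as a list of strings."""
    return json.loads(Path(path).read_text(encoding="utf-8"))
```

A bot could then call load_samples("samples.json") and post one entry, for example picked with random.choice.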
You can pass a run_name parameter to finetune and load_gpt2 if you want to store/load multiple models in a checkpoint folder.
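With several run_name values, each model is kept in its own subfolder of the checkpoint folder (checkpoint/run1 being the default). A small standard-library sketch can list which runs exist before choosing one to load; list_runs is a hypothetical helper, and it assumes exactly one subfolder per run.

```python
import os
from typing import List


def list_runs(checkpoint_dir: str = "checkpoint") -> List[str]:
    """Hypothetical helper: sorted names of the run subfolders under checkpoint_dir."""
    if not os.path.isdir(checkpoint_dir):
        return []  # nothing has been finetuned yet
    with os.scandir(checkpoint_dir) as entries:
        # Only directories count as runs; stray files are ignored.
        return sorted(entry.name for entry in entries if entry.is_dir())


if __name__ == "__main__":
    for name in list_runs():
        # Each name could be passed as gpt2.load_gpt2(sess, run_name=name)
        # in a fresh Python session.
        print(name)
```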
NB: Restart the Python session first if you want to finetune on another dataset or load another model.
Download files
Source Distribution
File details
Details for the file gpt_2_simple-0.8.1.tar.gz.
File metadata
- Download URL: gpt_2_simple-0.8.1.tar.gz
- Upload date:
- Size: 26.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0d620a17b4c4592190b637548afa0aaf26231c71a107165ee91e90929e398c11
MD5 | d6065545a8dae4e8c6c263e1e60b2450
BLAKE2b-256 | ecd5d1e9ab56bd82bc206fbd26a284115bee8101c2057a03d6ce8bcd069a1525
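The SHA256 digest listed above can be checked locally after downloading the source distribution. A minimal sketch using hashlib from the standard library; sha256_of and verify_download are hypothetical helper names, and the file is read in chunks so it never has to fit in memory.

```python
import hashlib

# SHA256 digest published for gpt_2_simple-0.8.1.tar.gz (see the hash table).
EXPECTED_SHA256 = "0d620a17b4c4592190b637548afa0aaf26231c71a107165ee91e90929e398c11"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 matches the published digest."""
    return sha256_of(path) == expected.lower()
```

For example, verify_download("gpt_2_simple-0.8.1.tar.gz") should return True for an intact copy of the archive; False means the file is corrupt or is not the published release.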