
Easy-to-use Wrapper for the GPT-2 117M, 345M, 774M, and 1.5B Transformer Models


gpt2-client




Made by Rishabh Anand • https://rish-16.github.io

What is it

GPT-2 is a Natural Language Processing model developed by OpenAI for text generation. It is the successor to the GPT (Generative Pre-trained Transformer) model and was trained on 40GB of text from the internet. It is built on the Transformer architecture introduced in the 2017 paper Attention Is All You Need. The model comes in four versions (117M, 345M, 774M, and 1558M) that differ in the number of parameters they contain.

The 1.5B (1558M) model is the largest GPT-2 model released by OpenAI.

gpt2-client is a wrapper around the original gpt-2 repository that offers the same functionality with more accessibility, comprehensibility, and utility. You can play around with all four GPT-2 models in less than five lines of code.

Note: The author of this client wrapper is not liable for any damage caused directly or indirectly by its use. Any names, places, and objects referenced by the model are fictional and are not intended to resemble real-life entities or organisations. Samples are unfiltered and may contain offensive content. User discretion is advised.

Installation

Install the client via pip. gpt2-client is best supported on Python >= 3.5 with TensorFlow 1.x. If you are using Python 2.x, some libraries may need to be reinstalled or upgraded with pip's --upgrade flag.

pip install gpt2-client

Note: gpt2-client is not compatible with TensorFlow 2.0
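If you are unsure which TensorFlow release is installed, a quick check before importing gpt2-client can save a confusing error later. The sketch below is not part of gpt2-client: `is_supported_tf_version` is a hypothetical helper that simply assumes any 1.x version string is acceptable, based on the compatibility note above.

```python
def is_supported_tf_version(version: str) -> bool:
    """Return True if `version` (e.g. tf.__version__) is a TensorFlow 1.x release.

    Assumption: per the note above, gpt2-client needs TensorFlow 1.x and does
    not work with 2.0, so only major version 1 is treated as supported.
    """
    major = version.split(".", 1)[0]  # "1.15.0" -> "1"
    return major.isdigit() and int(major) == 1
```

For example, pass it `tf.__version__` after `import tensorflow as tf`; if it returns False, installing a 1.x release (for instance with `pip install "tensorflow<2"`) is one way to meet the requirement.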

Getting started

1. Download the model weights and checkpoints

from gpt2_client import GPT2Client

gpt2 = GPT2Client('117M', save_dir='models') # This could also be `345M`, `774M`, or `1558M`. Rename `save_dir` to anything.
gpt2.load_model(force_download=False) # Use cached versions if available.

This creates a directory called models in the current working directory and downloads into it the weights, checkpoints, model JSON, and hyper-parameters required by the model. Once the files have finished downloading into the models directory, you do not need to call load_model() again.

Note: Set force_download=True to re-download the model and overwrite the existing cached weights and checkpoints.
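If a download is interrupted, the models directory can be left half-populated. The sketch below is a hypothetical helper, not part of gpt2-client, that checks whether a model's files are all present. It assumes gpt2-client keeps the same layout and file names as OpenAI's original gpt-2 download script (`<save_dir>/<model_name>/` holding the checkpoint, encoder, hyper-parameter, and BPE vocabulary files); adjust `EXPECTED_FILES` if your directory looks different.

```python
import os

# Files fetched per model by OpenAI's original gpt-2 download script.
# Assumption: gpt2-client stores them the same way, under <save_dir>/<model_name>/.
EXPECTED_FILES = (
    "checkpoint",
    "encoder.json",
    "hparams.json",
    "model.ckpt.data-00000-of-00001",
    "model.ckpt.index",
    "model.ckpt.meta",
    "vocab.bpe",
)


def is_model_cached(model_name: str, save_dir: str = "models") -> bool:
    """Return True if every expected file exists and is non-empty."""
    model_dir = os.path.join(save_dir, model_name)
    for name in EXPECTED_FILES:
        path = os.path.join(model_dir, name)
        # A missing or zero-byte file suggests an incomplete download.
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            return False
    return True
```

One way to use it is `gpt2.load_model(force_download=not is_model_cached('117M'))`, which forces a fresh download only when the cached files look incomplete.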

2. Start generating text!

from gpt2_client import GPT2Client

gpt2 = GPT2Client('117M') # This could also be `345M`, `774M`, or `1558M`

gpt2.generate(interactive=True) # Asks user for prompt
gpt2.generate(n_samples=4) # Generates 4 pieces of text
text = gpt2.generate(return_text=True) # Generates text and returns it as a list
gpt2.generate(interactive=True, n_samples=3) # A different prompt each time

As the sample above shows, the generation options are flexible: you can mix and match them depending on the text you need, whether that is several samples at once or one at a time from prompts.
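Since `return_text=True` hands the generated samples back instead of only printing them, they are easy to keep for later. The hypothetical `save_samples` helper below is a minimal sketch, assuming the return value is a list of strings as the comment above describes; it writes all samples to one UTF-8 file with a visible separator between them.

```python
from typing import Iterable

# Separator written between samples; any string works.
DEFAULT_SEPARATOR = "\n" + "=" * 40 + "\n"


def save_samples(samples: Iterable[str], path: str,
                 separator: str = DEFAULT_SEPARATOR) -> int:
    """Write generated samples to one UTF-8 file and return how many were written.

    Assumes `samples` is what generate(return_text=True) returns: strings.
    """
    cleaned = [s.strip() for s in samples]  # trim whitespace around each sample
    with open(path, "w", encoding="utf-8") as f:
        f.write(separator.join(cleaned) + "\n")
    return len(cleaned)
```

For example: `save_samples(gpt2.generate(n_samples=4, return_text=True), 'samples.txt')`, assuming `n_samples` and `return_text` can be combined like the other options above.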

3. Generating text from a batch of prompts

from gpt2_client import GPT2Client

gpt2 = GPT2Client('117M') # This could also be `345M`, `774M`, or `1558M`

prompts = [
  "This is a prompt 1",
  "This is a prompt 2",
  "This is a prompt 3",
  "This is a prompt 4"
]

text = gpt2.generate_batch_from_prompts(prompts) # Returns a list of generated text
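When working with a batch, it helps to keep each prompt next to its output. The sketch below assumes, without confirmation from the gpt2-client source, that `generate_batch_from_prompts` returns exactly one string per prompt in input order; the hypothetical `pair_prompts` helper fails loudly if the lengths disagree rather than silently misaligning them.

```python
from typing import List, Sequence, Tuple


def pair_prompts(prompts: Sequence[str],
                 outputs: Sequence[str]) -> List[Tuple[str, str]]:
    """Match each prompt with its generated text.

    Assumes one output per prompt, in the same order as `prompts`.
    """
    if len(prompts) != len(outputs):
        raise ValueError(
            "expected {} outputs, got {}".format(len(prompts), len(outputs)))
    return list(zip(prompts, outputs))
```

Usage: `for prompt, output in pair_prompts(prompts, text): print(prompt, '->', output)`. A list of pairs is used instead of a dict so that duplicate prompts are not collapsed.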

4. Fine-tuning GPT-2 on custom datasets

from gpt2_client import GPT2Client

gpt2 = GPT2Client('117M') # This could also be `345M`, `774M`, or `1558M`

my_corpus = './data/shakespeare.txt' # path to corpus
custom_text = gpt2.finetune(my_corpus, return_text=True) # Fine-tune on your custom dataset

Fine-tuning GPT-2 on a custom corpus or dataset is best done with a GPU or TPU. Google Colab, which provides access to both, is one tool you can use to re-train or fine-tune your custom model.
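In the example above, `finetune` takes the path to a single text file. If your data is spread across several files, you can concatenate them first. The hypothetical `build_corpus` helper below is one way to do that: it joins every `.txt` file in a folder (in name order, skipping the output file itself) into one UTF-8 corpus.

```python
import os


def build_corpus(src_dir: str, out_path: str, separator: str = "\n\n") -> int:
    """Concatenate the .txt files in src_dir into out_path; return how many were used.

    Files are read as UTF-8 in sorted name order. The output file is skipped
    if it lives inside src_dir, so re-running does not feed it into itself.
    """
    out_abs = os.path.abspath(out_path)
    names = sorted(
        name for name in os.listdir(src_dir)
        if name.endswith(".txt")
        and os.path.abspath(os.path.join(src_dir, name)) != out_abs
    )
    parts = []
    for name in names:
        with open(os.path.join(src_dir, name), encoding="utf-8") as f:
            parts.append(f.read().strip())
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(separator.join(parts) + "\n")
    return len(names)
```

For example, `build_corpus('./data/plays', './data/shakespeare.txt')` followed by `gpt2.finetune('./data/shakespeare.txt', return_text=True)`; the folder name is illustrative. GPT-2's vocabulary reserves `<|endoftext|>` as a document boundary, so `separator="<|endoftext|>"` is an option, though whether gpt2-client's `finetune` treats it specially is an assumption worth checking.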

5. Encoding and decoding text sequences

from gpt2_client import GPT2Client

gpt2 = GPT2Client('117M') # This could also be `345M`, `774M`, or `1558M`
gpt2.load_model()

# encoding a sentence
encs = gpt2.encode_seq("Hello world, this is a sentence")
# [15496, 995, 11, 428, 318, 257, 6827]

# decoding an encoded sequence
decs = gpt2.decode_seq(encs)
# Hello world, this is a sentence
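Every GPT-2 model has a context window of 1024 tokens, so long documents need to be split before they are fed back in. Building on `encode_seq` and `decode_seq`, the hypothetical `chunk_tokens` helper below splits a token list into consecutive windows; it assumes `encode_seq` returns a plain list of integer token IDs, as the output above suggests.

```python
from typing import List, Sequence

# GPT-2's context window (n_ctx) is 1024 tokens for all four released sizes.
GPT2_CONTEXT = 1024


def chunk_tokens(tokens: Sequence[int], size: int = GPT2_CONTEXT) -> List[List[int]]:
    """Split a token sequence into consecutive chunks of at most `size` tokens."""
    if size <= 0:
        raise ValueError("size must be positive")
    return [list(tokens[i:i + size]) for i in range(0, len(tokens), size)]
```

For example, `[gpt2.decode_seq(c) for c in chunk_tokens(gpt2.encode_seq(long_text))]` turns a long string of your own into pieces that each fit in the context window, assuming `decode_seq` accepts any list of token IDs as in the example above.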

Contributing

Suggestions, improvements, and enhancements are always welcome! If you have any issues, please do raise one in the Issues section. If you have an improvement, do file an issue to discuss the suggestion before creating a PR.

All ideas – no matter how outrageous – welcome!

Licence

MIT



Download files


Source Distribution

gpt2_client-2.1.5.tar.gz (14.3 kB)

File details

Details for the file gpt2_client-2.1.5.tar.gz.

File metadata

  • Download URL: gpt2_client-2.1.5.tar.gz
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.31.0 CPython/3.7.3

File hashes

Hashes for gpt2_client-2.1.5.tar.gz
  • SHA256: 4e3f4877c5694b7f56eb5286b4872d8e33b837c71493e7b5c99d5f5b90cc648b
  • MD5: 3072c78c0fa4d5af2ec00ee29ee7f0f2
  • BLAKE2b-256: d1f5b3af38978b1796036f181e6ebe879cd2420a196e675b86006cabf6f3253c
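To confirm that a downloaded source distribution matches the published SHA256 digest, you can hash it locally with Python's standard library. This is a generic sketch using `hashlib`, not a gpt2-client feature; `sha256_of` and `verify` are illustrative names.

```python
import hashlib

# SHA256 digest published for gpt2_client-2.1.5.tar.gz.
EXPECTED_SHA256 = "4e3f4877c5694b7f56eb5286b4872d8e33b837c71493e7b5c99d5f5b90cc648b"


def sha256_of(path: str, chunk_size: int = 65536) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()


def verify(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 digest matches `expected` (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

For example, `verify('gpt2_client-2.1.5.tar.gz')` should return True for an intact download. pip can also enforce the digest itself in hash-checking mode, using a requirements line such as `gpt2-client==2.1.5 --hash=sha256:<digest>`.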

