GeNN (generative neural networks) is a high-level interface for text applications built on PyTorch, supporting recurrent networks (LSTM, GRU) and GPT-2.
- Parsing of txt, json, and csv files.
- NLTK, regex, and spaCy tokenization support.
- GloVe and fastText pretrained embeddings, with the ability to fine-tune them on your data.
- Architectures and customization:
  - GPT-2 with small, medium, and large variants.
  - LSTM and GRU with configurable size.
  - Configurable number of layers and batch size.
- Text generation:
  - Seed token chosen at random from the first n tokens of all instances, or the most frequent token used as the seed.
  - Top-K sampling for next-token prediction with configurable K.
  - Nucleus (top-p) sampling for next-token prediction with a configurable probability threshold.
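The generation options listed above can be illustrated with a minimal, framework-free sketch. It is not genn's implementation. It works on plain Python lists instead of PyTorch tensors, and every function name here (softmax, top_k_filter, nucleus_filter, sample, pick_seed) is hypothetical. One reading of the feature list is also an assumption: "the most frequent token" is taken to mean the most frequent token in the whole corpus.

```python
import math
import random
from collections import Counter


def softmax(logits):
    """Turn raw model scores into probabilities (max-shifted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def _renormalize(probs, keep):
    """Zero out tokens not in `keep` and rescale the remaining ones to sum to 1."""
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


def top_k_filter(probs, k):
    """Top-K: keep only the k most probable tokens."""
    if k <= 0:
        raise ValueError("k must be positive")
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return _renormalize(probs, set(order[:k]))


def nucleus_filter(probs, p):
    """Nucleus (top-p): keep the smallest set of most probable tokens whose mass reaches p."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, mass = set(), 0.0
    for i in order:
        keep.add(i)
        mass += probs[i]
        if mass >= p:
            break
    return _renormalize(probs, keep)


def sample(probs, rng=random):
    """Draw one token index according to the (filtered) probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def pick_seed(documents, n=1, most_frequent=False, rng=random):
    """Pick a starting token for generation from a list of token lists.

    Either draw at random from the pooled first n tokens of every document,
    or (assumption) return the most frequent token in the whole corpus.
    """
    if most_frequent:
        counts = Counter(tok for doc in documents for tok in doc)
        return counts.most_common(1)[0][0]
    candidates = [tok for doc in documents for tok in doc[:n]]
    return rng.choice(candidates)
```

A single generation step would then look like `sample(top_k_filter(softmax(logits), 10))` or `sample(nucleus_filter(softmax(logits), 0.9))`. Top-K always keeps a fixed number of candidates. Nucleus sampling instead adapts the candidate count to how peaked the distribution is.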
How to install
Use the package manager pip to install genn:
pip install genn
Requirements:
- PyTorch 1.4.0
pip install torch==1.4.0
- PyTorch Transformers
pip install pytorch_transformers
- NumPy
pip install numpy
- fastText
pip install fasttext
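To confirm that the packages above are installed, a small check script can help. It uses only the standard library's `importlib.metadata` (Python 3.8+). It assumes the installed distribution names match the pip commands above, and it treats a local version suffix such as `+cpu` as matching the pinned PyTorch version. Both assumptions are for illustration only; `check_dependencies` and `REQUIRED` are hypothetical names and not part of genn.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names (taken from the pip commands) and pinned versions.
REQUIRED = {
    "genn": None,
    "torch": "1.4.0",
    "pytorch_transformers": None,
    "numpy": None,
    "fasttext": None,
}


def check_dependencies(required=REQUIRED):
    """Return {name: (installed_version_or_None, ok)} for each required package."""
    report = {}
    for name, pinned in required.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Ignore a local version label such as "1.4.0+cpu" when comparing to the pin.
        ok = installed is not None and (
            pinned is None or installed.split("+")[0] == pinned
        )
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_dependencies().items():
        status = "ok" if ok else "missing or wrong version"
        print(f"{name:22} {installed or '-':10} {status}")
```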
Usage
from genn import Preprocessing, LSTMGenerator, GPT2

# LSTM example
ds = Preprocessing("data.txt")
gen = LSTMGenerator(ds, nLayers=2, batchSize=16, embSize=64, lstmSize=16, epochs=20)
# Train the model
gen.run()
# Generate 5 new documents
print(gen.generate_document(5))

# GPT-2 example
gen = GPT2("data.txt", taskToken="Movie:", epochs=7, variant="medium")
# Train the model
gen.run()
# Generate 10 new documents
print(gen.generate_document(10))
For more examples on how to use Preprocessing, please refer to this file.
For more examples on how to use LSTMGenerator and GRUGenerator, please refer to this file.
For more examples on how to use GPT2, please refer to this file.
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Distributed under the MIT License. See LICENSE for more information.