Generative Summarization for Data Augmentation

gensum - Generative Summarization for Data Augmentation

License PyPI version Python 3.10

Introduction

Imbalanced class distribution is a common problem in machine learning. Undersampling and oversampling are two methods of addressing it. Techniques such as SMOTE and MLSMOTE have been proposed, but the high-dimensional nature of the numerical vectors created from text can make other data augmentation approaches preferable.

gensum is an NLP library based on absum that uses generative summarization to perform data augmentation by oversampling under-represented classes in text classification datasets. Recent advancements in generative models such as ChatGPT make this approach well suited to producing realistic but unique data for the augmentation process.

It uses ChatGPT by default, but is designed in a modular way so that you can use any large language model capable of generative summarization. gensum is format agnostic, expecting only a DataFrame containing a text column and a classifier column.

Installation

Via pip

pip install gensum

From source

git clone https://github.com/aaronbriel/gensum.git
cd gensum
pip install [--editable] .

or

pip install git+https://github.com/aaronbriel/gensum.git

Usage

gensum expects a DataFrame containing a text column, which defaults to 'text', and a classifier column, which defaults to 'classifier'. All available parameters are detailed in the Parameters section below. Be sure to set the OPENAI_API_KEY environment variable before running the code.

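import pandas as pd
from gensum import Augmentor

# Placeholder: replace with the path to your input file, ending in .csv
csv_path = 'path_to_csv'
df = pd.read_csv(csv_path)
# Here the classifier column is named 'intent' instead of the default 'classifier'
augmentor = Augmentor(df, text_column='text', classifier='intent')
df_augmented = augmentor.gen_sum_augment()
# Store the resulting DataFrame as a CSV alongside the input file
df_augmented.to_csv(csv_path.replace('.csv', '-augmented.csv'), encoding='utf-8', index=False)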
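
Because the default backend needs OPENAI_API_KEY, it can help to fail early with a clear message when the variable is missing. The following is a minimal sketch rather than part of gensum: the helper name `require_api_key` and its arguments are hypothetical and exist only for illustration.

```python
import os


def require_api_key(env=None, name="OPENAI_API_KEY"):
    """Return the API key from the environment, or raise a clear error.

    Hypothetical helper for illustration; gensum does not provide it.
    """
    # Default to the real process environment; a dict can be passed for testing.
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"{name} is not set; export it before running gensum.")
    return key
```

Calling `require_api_key()` before constructing the Augmentor surfaces a missing key before any augmentation work starts.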

NOTE: The output dataframe contains only the augmented rows.
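
Since only the augmented rows are returned, a common follow-up is to append them to the original data before training. The sketch below uses small made-up DataFrames as stand-ins for `df` and for the output of `gen_sum_augment()`, and it assumes that the augmented DataFrame has the same columns as the input.

```python
import pandas as pd

# Stand-ins for the original data and for the output of gen_sum_augment();
# the rows are made up for illustration.
df = pd.DataFrame(
    {"text": ["book a flight", "cancel my order"], "intent": ["travel", "orders"]}
)
df_augmented = pd.DataFrame(
    {"text": ["get me a plane ticket"], "intent": ["travel"]}
)

# Append the augmented rows to the original rows and renumber the index.
df_combined = pd.concat([df, df_augmented], ignore_index=True)
```

`ignore_index=True` avoids duplicate index labels in the combined DataFrame.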

Parameters

Name Type Description
df (:class:pandas.DataFrame, required, defaults to None) DataFrame containing a text column and a classifier column.
text_column (:obj:string, optional, defaults to "text") Column in df containing text.
classifier (:obj:string, optional, defaults to "classifier") Column in df containing the classifier to augment data for.
classifier_values (:obj:string, optional, defaults to None) Specific classifier values to augment data for.
min_length (:obj:int, optional, defaults to 10) The minimum length of the sequence to be generated. Between 0 and infinity.
max_length (:obj:int, optional, defaults to 50) The maximum length of the sequence to be generated. Between min_length and infinity.
num_samples (:obj:int, optional, defaults to 20) Number of samples to pull from the DataFrame for a specific classifier value when generating new samples with generative summarization.
threshold (:obj:int, optional, defaults to mean count for all classifier values) Maximum number of rows for each classifier value, normally the under-sample maximum.
prompt (:obj:string, optional, defaults to "Create SUMMARY_COUNT unique, informally written sentences similar to the ones listed here:") The prompt to use for the generative summarization. If you change the prompt, please be sure to keep the SUMMARY_COUNT string in it somewhere as this is expected and replaced based on the append count calculated for said classifier value.
llm (:obj:string, optional, defaults to 'chatgpt') The generative LLM to use for summarization.
model (:obj:string, optional, defaults to 'gpt-3.5-turbo') The specific model to use.
temperature (:obj:float, optional, defaults to 0) Determines the randomness of the generated sequences. Between 0 and 1, where a higher value means the generated sequences will be more random.
debug (:obj:bool, optional, defaults to True) If set, prints generated summarizations.
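
To make the relationship between threshold and the SUMMARY_COUNT placeholder concrete, here is a rough sketch of how an append count per classifier value could be derived and substituted into the prompt. It is an illustration under stated assumptions, not gensum's actual implementation: the function names `append_counts` and `build_prompt`, the truncation of the mean to an integer, and the way sampled texts are joined to the prompt are all hypothetical.

```python
from collections import Counter
from statistics import mean

DEFAULT_PROMPT = (
    "Create SUMMARY_COUNT unique, informally written sentences "
    "similar to the ones listed here:"
)


def append_counts(labels, threshold=None):
    """Rows to add per class so that each reaches the threshold (assumed logic)."""
    counts = Counter(labels)
    if threshold is None:
        # Documented default: mean count across all classifier values.
        # Truncating to an integer is an assumption made for this sketch.
        threshold = int(mean(counts.values()))
    return {label: max(threshold - n, 0) for label, n in counts.items()}


def build_prompt(count, samples, template=DEFAULT_PROMPT):
    """Substitute SUMMARY_COUNT and list the sampled texts (assumed layout)."""
    if "SUMMARY_COUNT" not in template:
        raise ValueError("prompt must contain the SUMMARY_COUNT placeholder")
    header = template.replace("SUMMARY_COUNT", str(count))
    return header + "\n" + "\n".join(samples)


# Made-up class distribution: 6 'travel', 2 'orders', 1 'billing' (mean is 3).
labels = ["travel"] * 6 + ["orders"] * 2 + ["billing"] * 1
needed = append_counts(labels)
prompt = build_prompt(needed["billing"], ["why was I charged twice", "refund my card"])
```

In this sketch, classes already at or above the threshold get an append count of 0, so no prompt would need to be sent for them.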

Citation

Please reference this library if you use this work in a published or open-source project.
