general-classifier

general-classifier is a Python package designed for multi-topic text classification leveraging Large Language Models (LLMs). It allows users to define multiple classification topics, manage categories within each topic, classify text data using various language models, evaluate classification performance, and iteratively improve classification prompts.

Features

  • Multi-topic Classification:
    Classify text across multiple defined topics simultaneously, each with its own set of categories.

  • Memory-efficient Batch Processing:
    Process large datasets in batches with automatic GPU memory management and model reloading between batches.

  • Dynamic Topic & Category Management:
    Easily add, remove, and manage multiple classification topics and their respective categories.

  • Flexible Model Integration:
    Supports integration with:

    • Local Transformers models (with torch)
    • OpenAI API models
    • DeepInfra-hosted models

    Each backend supports both direct model output and guided/constrained prediction.
  • Performance Evaluation:
    Comprehensive evaluation metrics including:

    • Per-topic accuracy
    • Confusion matrices
    • Micro precision, recall, and F1 scores
  • Iterative Prompt Improvement:
    Leverage LLMs to automatically refine classification prompts and improve accuracy over time.

  • Interactive Interface:
    Optional Jupyter widget-based interface for easy management of topics, categories, and classification tasks.

  • Conditional Classification:
    Support for dependent classifications with conditions based on previous results.

Installation

Ensure you have Python 3.7 or higher installed. Install the required dependencies using pip:

pip install torch transformers openai guidance ipywidgets

Requirements

  • Python 3.7+
  • PyTorch
  • Transformers
  • OpenAI Python client (for OpenAI API)
  • Guidance (for guided generation)
  • IPython/Jupyter (for interactive interface)

Quick Start

1. Define Topics and Categories

Begin by defining classification topics and their respective categories.

from general_classifier import gc

# Add a new topic with categories
gc.add_topic(
    topic_name="Car Brands",
    categories=["BMW", "Audi", "Mercedes"]
)

# Add another category to the existing topic (topics receive letter IDs: "A", "B", ...)
gc.add_category("A", categoryName="Toyota")

# Display all defined topics and their categories
gc.show_topics_and_categories()

2. Set Models

Configure the main classification model and optionally a separate model for prompt improvement.

# Set the main classification model (e.g., a local Transformers model)
gc.setModel(
    newModel="meta-llama/Llama-2-7b-chat-hf", 
    newModelType="Transformers",
    newInferenceType="transformers"  # Options: "transformers", "guidance"
)

# Optionally set a separate model for prompt improvement
gc.setPromptModel(
    newPromptModel="gpt-4", 
    newPromptModelType="OpenAI", 
    api_key="your-openai-api-key",
    newInferenceType="cloud"
)

3. Classify a Single Text

Classify a single piece of text across all defined topics.

text_to_classify = "The new BMW X5 has impressive features."
results, probabilities = gc.classify(
    text=text_to_classify,
    isItASingleClassification=True,  # Print results to console
    constrainedOutput=True  # Use constrained output mode
)

print(f"Classification results: {results}")
print(f"Confidence scores: {probabilities}")

4. Classify a Dataset

Classify text data from a CSV file and evaluate performance.

# Classify data from 'data.csv' with evaluation enabled
gc.classify_table(
    dataset="data",  # Filename without .csv extension
    withEvaluation=True,  # Compare against ground truth
    constrainedOutput=True,  # Use constrained output mode
    BATCH_SIZE=10  # Process 10 rows per batch
)

This will generate data_(result).csv containing classification results and performance metrics.
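To illustrate what consuming the results file might look like, here is a minimal sketch using only the standard library. The column layout shown (one column per topic, named after the topic) is an assumption for illustration; inspect your own data_(result).csv to see the actual format.

```python
import csv
import io

# Hypothetical excerpt of a results file; the real column layout of
# data_(result).csv depends on your topics and is an assumption here.
sample = io.StringIO(
    "text,Car Brands\n"
    "The new BMW X5 has impressive features.,BMW\n"
    "Audi unveiled its latest e-tron model.,Audi\n"
)

rows = list(csv.DictReader(sample))
for row in rows:
    print(f"{row['Car Brands']:>6}  <-  {row['text']}")
```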

5. Evaluate Prompt Performance

Assess the performance of a specific topic's prompt on a dataset.

# Evaluate prompt accuracy for topic 'A' on dataset 'mydata.csv'
gc.check_prompt_performance_for_topic(
    topicId="A", 
    dataset="mydata", 
    constrainedOutput=True
)

6. Improve Prompts Iteratively

Enhance the classification prompt for a specific topic using LLM feedback.

# Iteratively improve prompt for topic 'A' using dataset 'mydata.csv'
gc.improve_prompt(
    topicId="A", 
    dataset="mydata", 
    constrainedOutput=True, 
    num_iterations=10
)

This function will refine the prompt over multiple iterations, seeking to improve classification accuracy.
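The improve/evaluate loop can be sketched as follows. This is an illustration of the general pattern, not the package's actual implementation; the `evaluate` and `suggest` callables here are mocks standing in for dataset evaluation and the LLM rewrite step.

```python
# Sketch of an iterative prompt-improvement loop: propose a candidate,
# score it, and keep the best-scoring prompt seen so far.
def improve_loop(prompt, evaluate, suggest, num_iterations=10):
    best_prompt, best_accuracy = prompt, evaluate(prompt)
    current = prompt
    for _ in range(num_iterations):
        current = suggest(current, evaluate(current))  # e.g. an LLM rewrite
        accuracy = evaluate(current)
        if accuracy > best_accuracy:
            best_prompt, best_accuracy = current, accuracy
    return best_prompt, best_accuracy

# Mock evaluate/suggest so the sketch runs without a model:
scores = {"v0": 0.70, "v1": 0.65, "v2": 0.80}
evaluate = lambda p: scores.get(p, 0.0)
suggest = lambda p, acc: {"v0": "v1", "v1": "v2", "v2": "v2"}.get(p, "v2")

best, acc = improve_loop("v0", evaluate, suggest, num_iterations=3)
print(best, acc)  # keeps "v2", the highest-scoring candidate
```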

7. Managing Topics and Categories

Functions to manage topics and categories:

# Update a topic's prompt
gc.setPrompt(topicId="A", newPrompt="New improved prompt for classification.")

# Remove a specific category from a topic
gc.remove_category(topicId="A", categoryId="a")

# Remove a specific topic
gc.remove_topic("A")

# Clear all topics
gc.removeAllTopics()

8. Saving and Loading Topics

Persist and retrieve your topic configurations:

# Save topics to a JSON file
gc.save_topics("my_classification_topics")

# Load topics from a JSON file
gc.load_topics("my_classification_topics")

Advanced Features

Conditional Classification

You can create dependencies between topics using conditions:

# Add a topic with a condition
condition_topic = gc.add_topic(
    topic_name="Car Features",
    categories=["Sport", "Luxury", "Economy"],
    condition="A==a"  # Only classify if topic A resulted in category a
)
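The semantics of a condition string like "A==a" can be sketched as a simple lookup against earlier predictions. The package's internal parsing may differ; this is only meant to show how the topic-ID/category-ID pair gates a dependent classification.

```python
# Sketch: check a condition string such as "A==a" against previous results.
def condition_met(condition, previous_results):
    """previous_results maps topic IDs to predicted category IDs."""
    if not condition:
        return True  # unconditional topic: always classify
    topic_id, expected_category = condition.split("==")
    return previous_results.get(topic_id) == expected_category

previous = {"A": "a"}                    # topic A was classified as category a
print(condition_met("A==a", previous))   # True  -> classify this topic
print(condition_met("A==b", previous))   # False -> skip this topic
```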

Batch Processing

For large datasets, the classifier processes data in batches and automatically manages GPU memory:

gc.classify_table(
    dataset="large_dataset",
    withEvaluation=True,
    BATCH_SIZE=5  # Smaller batches for larger models
)
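The batching pattern itself is straightforward; a minimal sketch (independent of the package) of splitting rows into fixed-size chunks:

```python
# Split rows into fixed-size chunks; the real classifier additionally
# reloads the model between chunks to free GPU memory.
def batches(rows, batch_size):
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]

rows = [f"row {i}" for i in range(12)]
sizes = [len(b) for b in batches(rows, 5)]
print(sizes)  # [5, 5, 2]
```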

GPU Memory Management

The system includes built-in memory management for GPU-based models:

# These functions handle model loading and unloading between batches
# to prevent GPU memory issues
# They are automatically used in classify_table but can be called manually
model = gc.load_model()
# ... do some processing ...
gc.unload_model(model)

Interactive Interface

For a more user-friendly experience, you can use the interactive Jupyter interface:

# Launch the widget-based interface in a Jupyter notebook
gc.openInterface()

API Reference

Main Classification Functions

classify(text: str, isItASingleClassification: bool = True, constrainedOutput: bool = True, withEvaluation: bool = False, groundTruthRow: list = None) -> tuple

Classifies a piece of text across all defined topics.

  • Returns: Tuple of (predictions_list, probabilities_list)
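A sketch of consuming the return value, assuming one prediction and one probability per defined topic, in topic order (the topic names and values below are hypothetical):

```python
# Hypothetical topics and return values, paired up in topic order.
topic_names = ["Car Brands", "Car Features"]
predictions = ["BMW", "Sport"]       # predictions_list from classify(...)
probabilities = [0.91, 0.77]         # probabilities_list from classify(...)

for name, pred, prob in zip(topic_names, predictions, probabilities):
    print(f"{name}: {pred} ({prob:.0%})")
```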

classify_table(dataset: str, withEvaluation: bool = False, constrainedOutput: bool = True, BATCH_SIZE: int = 10)

Classifies each row in a CSV dataset with optional batch processing.

Model Management

setModel(newModel: str, newModelType: str, api_key: str = "", newInferenceType: str = "transformers")

Sets the main classification model.

setPromptModel(newPromptModel: str, newPromptModelType: str, api_key: str = "", newInferenceType: str = "guidance")

Sets the model used for prompt improvement.

load_model() -> object

Loads a fresh instance of the model to GPU.

unload_model(model_to_unload: object) -> None

Thoroughly unloads a model from GPU memory.

Topic Management

add_topic(topic_name: str, categories: list = [], condition: str = "", prompt: str = default_prompt) -> dict

Adds a new classification topic.

remove_topic(topic_id_str: str)

Removes a topic by its ID.

add_category(topicId: str, categoryName: str, Condition: str = "")

Adds a category to a specified topic.

remove_category(topicId: str, categoryId: str)

Removes a category from a specified topic.

setPrompt(topicId: str, newPrompt: str)

Updates the prompt for a specified topic.

removeAllTopics()

Removes all defined topics and resets related counters.

Persistence

save_topics(filename: str)

Saves all topics to a JSON file.

load_topics(filename: str)

Loads topics from a JSON file.

Prompt Improvement

check_prompt_performance_for_topic(topicId: str, dataset: str, constrainedOutput: bool = True, groundTruthCol: int = None)

Evaluates the performance of a specific topic's prompt.

improve_prompt(topicId: str, dataset: str, constrainedOutput: bool = True, groundTruthCol: int = None, num_iterations: int = 10)

Iteratively improves a prompt using LLM feedback.

getLLMImprovedPromptWithFeedback(old_prompt: str, old_accuracy: float, topic_info: dict) -> str

Gets an improved prompt suggestion from the LLM.

Interface

openInterface()

Opens an interactive widget-based interface in Jupyter.

Examples

Example 1: Medical Record Classification

# Set up model
gc.setModel("meta-llama/Llama-2-7b-chat-hf", "Transformers")

# Define medical topics
diagnosis = gc.add_topic(
    topic_name="Diagnosis",
    categories=["Positive", "Negative", "Inconclusive"],
    prompt="Classify the medical report diagnosis as [CATEGORIES]. Report: '[TEXT]'. The diagnosis is:"
)

# Classify a medical report
medical_text = "Patient shows no signs of infection. All tests negative."
results, _ = gc.classify(medical_text)
print(f"Diagnosis classification: {results[0]}")  # Expected: "Negative"
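The [CATEGORIES] and [TEXT] placeholders in the prompt above can be filled with plain string substitution. This sketch shows the idea; the package's exact internal formatting (e.g. how categories are separated) is an assumption here.

```python
# Fill the [CATEGORIES] and [TEXT] placeholders of a topic prompt.
prompt = ("Classify the medical report diagnosis as [CATEGORIES]. "
          "Report: '[TEXT]'. The diagnosis is:")
categories = ["Positive", "Negative", "Inconclusive"]
text = "Patient shows no signs of infection. All tests negative."

filled = (prompt
          .replace("[CATEGORIES]", ", ".join(categories))
          .replace("[TEXT]", text))
print(filled)
```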

Example 2: Multi-stage Document Processing

# First level classification
doc_type = gc.add_topic(
    topic_name="Document Type",
    categories=["Invoice", "Contract", "Report"]
)

# Second level (dependent on first)
invoice_status = gc.add_topic(
    topic_name="Invoice Status",
    categories=["Paid", "Pending", "Overdue"],
    condition=f"{doc_type['id']}==a"  # Only if doc is Invoice
)

contract_type = gc.add_topic(
    topic_name="Contract Type",
    categories=["Employment", "Service", "NDA"],
    condition=f"{doc_type['id']}==b"  # Only if doc is Contract
)

# Process a batch of documents
gc.classify_table("documents", withEvaluation=True)

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any enhancements, bug fixes, or new features.

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/YourFeature)
  3. Commit your changes (git commit -m 'Add some feature')
  4. Push to the branch (git push origin feature/YourFeature)
  5. Open a Pull Request

License

This project is licensed under the MIT License - see the LICENSE file for details.


Disclaimer

This tool is provided "as is" without any warranty. When using API-based models (OpenAI, DeepInfra), ensure you comply with their respective terms of service and usage policies.

Contact

For questions or support, please open an issue on the GitHub repository.
