
Train QA models with feedback using Optuna and SimpleTransformers

Project description

QA Trainer 📚🤖

QA Trainer is a Python package for fine-tuning Question Answering models using feedback data.
It uses SimpleTransformers for model training and Optuna for hyperparameter tuning.



📂 Example Dataset Format

The dataset must be a JSON array of feedback entries. Each entry should have:

  • "question" → The question text
  • "answer" → The answer text (used as both context and answer in training)
  • "feedback" → A positive number to include the entry in training; negative or zero excludes it

Example feedback_data.json for a question-answering model. The same format applies to other model types as well; see the SimpleTransformers Python library documentation for details.

[
    {
        "question": "What is AI?",
        "answer": "AI stands for Artificial Intelligence.",
        "feedback": 1
    },
    {
        "question": "Who is the president of Mars?",
        "answer": "Elon Musk.",
        "feedback": -1
    },
    {
        "question": "What is Python?",
        "answer": "Python is a programming language.",
        "feedback": 1
    }
]
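Internally, entries with positive feedback need to be converted into the SQuAD-style structure that SimpleTransformers trains on. The sketch below shows one plausible way to do that conversion; `to_qa_examples` is a hypothetical helper for illustration, not part of this package.

```python
import json

def to_qa_examples(entries):
    """Keep only positively rated entries and wrap them in the
    SQuAD-style format SimpleTransformers expects. Per the dataset
    description, the answer text doubles as the context."""
    examples = []
    for i, entry in enumerate(entries):
        if entry["feedback"] <= 0:
            continue  # negative/zero feedback is excluded from training
        examples.append({
            "context": entry["answer"],
            "qas": [{
                "id": str(i),
                "question": entry["question"],
                "is_impossible": False,
                # answer text spans the whole context, so it starts at 0
                "answers": [{"text": entry["answer"], "answer_start": 0}],
            }],
        })
    return examples

entries = json.loads("""[
  {"question": "What is AI?", "answer": "AI stands for Artificial Intelligence.", "feedback": 1},
  {"question": "Who is the president of Mars?", "answer": "Elon Musk.", "feedback": -1}
]""")
examples = to_qa_examples(entries)
print(len(examples))  # only the positively rated entry survives
```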

🚀 Quick Start

from feedback_rl_transformers import train_with_feedback

train_with_feedback(
    feedback_data_path="feedback_data.json",  # Path to your feedback dataset
    model_type="bert",                        # Model type (see table below)
    model_path="bert-base-uncased",           # Pretrained model name or path
    output_dir="fine_tuned_model",            # Where to save the final model
    n_trials=5,                               # Optuna hyperparameter tuning trials
    use_cuda=False                            # Set True to use GPU if available
)
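After training, the saved model can be loaded for inference with SimpleTransformers' `QuestionAnsweringModel`. The sketch below builds the SQuAD-style prediction payload that `predict()` expects; the model-loading lines are left as comments because they require the trained model on disk, and `make_predict_payload` is a hypothetical helper, not part of this package.

```python
def make_predict_payload(question, context, qid="0"):
    """Wrap a question/context pair in the list-of-dicts format that
    SimpleTransformers' QuestionAnsweringModel.predict() consumes."""
    return [{"context": context, "qas": [{"id": qid, "question": question}]}]

to_predict = make_predict_payload(
    "What is Python?",
    "Python is a programming language.",
)

# With simpletransformers installed and a model in fine_tuned_model/:
# from simpletransformers.question_answering import QuestionAnsweringModel
# model = QuestionAnsweringModel("bert", "fine_tuned_model", use_cuda=False)
# answers, probabilities = model.predict(to_predict)
```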

🧠 Supported Models

| Model Name | Code |
| --- | --- |
| ALBERT | albert |
| BERT | bert |
| BERTweet | bertweet |
| BigBird* | bigbird |
| CamemBERT | camembert |
| DeBERTa* | deberta |
| DistilBERT | distilbert |
| ELECTRA | electra |
| FlauBERT | flaubert |
| HerBERT | herbert |
| LayoutLM | layoutlm |
| LayoutLMv2 | layoutlmv2 |
| Longformer* | longformer |
| MPNet* | mpnet |
| MobileBERT | mobilebert |
| RemBERT | rembert |
| RoBERTa | roberta |
| SqueezeBert* | squeezebert |
| XLM | xlm |
| XLM-RoBERTa | xlmroberta |
| XLNet | xlnet |

*Some large models may require more memory or a GPU.


⚙️ Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| feedback_data_path | str | Path to the feedback JSON dataset |
| model_type | str | Model type code (see table above) |
| model_path | str | Pretrained model name or path |
| output_dir | str | Directory where the fine-tuned model will be saved |
| n_trials | int | Number of Optuna trials for hyperparameter tuning |
| use_cuda | bool | Set True to use a GPU if available |

📤 Output

After training:

  • The fine-tuned model will be saved in output_dir
  • Best hyperparameters from Optuna will be printed

Project details


Release history

This version

0.1

Download files

Source Distribution: feedback_rl_transformers-0.1.tar.gz (4.3 kB)

Built Distribution: feedback_rl_transformers-0.1-py3-none-any.whl (4.8 kB, Python 3)

File details

feedback_rl_transformers-0.1.tar.gz (source, 4.3 kB, uploaded via twine/6.1.0 on CPython/3.12.6):

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 9551eec6dac9d2fdd53082dd6ff9376426ec0295088916c4d6c37de85e7c5f6a |
| MD5 | 4d6c4f7605559d2932c0efa57bb28eae |
| BLAKE2b-256 | 70dea9e454c985e26aa4f950886f8970015c49d3732524a2465106dc19bc2c6b |

feedback_rl_transformers-0.1-py3-none-any.whl:

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 25e0ea37ce37ca513a505d8a01d903dd5ee22fd5b1ec0c28eaccc4e44e29a787 |
| MD5 | 61e82b8a48a90cd4b7d05e70dc4679b7 |
| BLAKE2b-256 | 50661936fd5eb37e7499731ea177dffe312543bde136cca4df2c6ea483917b39 |
