
A tool for question generation based on LMs

Project description

How to use

Our code provides question generation (QG) capabilities. Please follow these steps to train your QG models.

Environment

Run pip install -r requirements.txt to install the required packages.

Train QG Models

  1. Prepare data

    You can use the HotpotQA dataset (./data/hotpotqa) or the SQuAD1.1 dataset (./data/squad1.1) we provide, or use your own dataset, which must provide passages, references, and answers.

    If you use the datasets we provide, run python process.py to process the data and make it ready for training.

    # For example: process HotpotQA for T5-based training
    save_dir = 'directory to save processed data'    # replace with your own path
    hot = T5Hotpot()
    hot.deal_hotpot(save_dir)    # writes the processed data to save_dir
    
  2. Train your own QG models

    We provide code to train T5-, FLAN-T5-, and BART-based QG models. Update the training arguments in ./args, then run the code file for the specific model you want to train.

    # model name: code file name
    BART-base/large_finetune: bart.py
    T5-base/large_finetune: T5.py
    Flan-T5-base/large_finetune: Flan-T5.py
    Flan-T5-xl/xxl_LoRA: Flan-T5-lora.py
    

    For example, to train a T5-base_finetune QG model on HotpotQA:

    • Update the arg file ./args/t5_base_hotpot.json (note that per_device_train_batch_size 16 with gradient_accumulation_steps 8 gives an effective batch size of 128 per device)

      {
          "model_name_or_path": "google-t5/t5-base",
          "tokenizer_name_or_path": "google-t5/t5-base",
          "max_len": 512,
          "target_max_len": 163,
          "output_dir": "./model/t5-base/hotpotqa",
          "overwrite_output_dir": true,
          "per_device_train_batch_size": 16,
          "per_device_eval_batch_size": 16,
          "gradient_accumulation_steps": 8,
          "learning_rate": 0.0001,
          "warmup_steps": 500,
          "weight_decay": 0.01,
          "num_train_epochs": 10,
          "seed": 42,
          "do_train": true,
          "do_eval": true,
          "logging_steps": 100,
          "logging_dir": "./model/t5-base/logs",
          "train_file_path": "./data/data_t5/hotpotqa/pt/train.pt",
          "valid_file_path": "./data/data_t5/hotpotqa/pt/dev.pt",
          "remove_unused_columns": false,
          "prediction_loss_only": true,
          "load_best_model_at_end": true,
          "evaluation_strategy": "epoch",
          "save_strategy": "epoch"
      }
      
    • Run the code in T5.py to train the QG model (a quick way to sanity-check the trained checkpoint is sketched after this list)

      # train model
      arg_path = './args/t5_base_hotpot.json'
      data_dir = 'directory of your data'    # processed data
      train_path = data_dir + 'train.json'
      dev_path = data_dir + 'dev.json'
      test_path = data_dir + 'test.json'   # optional
      train_qg(arg_path, train_path, dev_path, test_path)
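
Once training finishes, you can sanity-check the saved checkpoint directly with the Hugging Face transformers API. The following is a minimal sketch, not one of our scripts: the paths follow the example config above (with the tokenizer under a tokenizer/ subdirectory, as in the prediction example later on), and the input string is a hypothetical placeholder, so use whatever input format process.py produced for your model.

    # Sanity-check a trained checkpoint (sketch; not part of the provided scripts)
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    model_dir = './model/t5-base/hotpotqa/'    # output_dir from the arg file
    tokenizer = AutoTokenizer.from_pretrained(model_dir + 'tokenizer/')
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

    # hypothetical input; match the format that process.py produced
    text = 'answer: Paris  context: Paris is the capital of France.'
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    output_ids = model.generate(**inputs, max_length=163, num_beams=4)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))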
      

Generate questions using QG models

We provide different types of QG models for question generation.

  1. Use fine-tuned models in lmqg

    Run python apply_lmqg.py to get the questions generated by models in lmqg (a sketch of calling lmqg directly follows this list).

    model_name = 'lmqg/flan-t5-base-squad-qg'
    data_json_path = './data/squad-test.json'
    save_path = './result/squad-test.csv'
    questions = apply_lmqg(
        model_name=model_name,
        data_json_path=data_json_path,
        save_path=save_path,
        batch_size=8
    )
    
  2. Use the fine-tuned (including LoRA) models we provide for QG

    Run the code file for the specific model type. For example, to use the T5-base_finetune model to generate questions on the HotpotQA test set, run the code in T5.py.

    # test
    arg_path = './args/t5_base_hotpot.json'
    test_pt_path = './data/data_t5/hotpotqa/pt/test.pt'    # tokenized test data
    model_dir = './model/t5-base/hotpotqa/'
    tokenizer_dir = model_dir + 'tokenizer/'
    result_save_path = './result/t5-base-finetune/hotpotqa/prediction.csv'
    decoded_texts = predict(model_dir, tokenizer_dir, test_pt_path, result_save_path)
    
  3. Use few-shot mode

    Run the code file for the specific model type. For example, to use a Flan-T5 model in few-shot mode to generate questions, run the code in Flan-T5.py.

    # few-shot prediction
    model = 'google/flan-t5-large'
    few_shot_path = './data/squad1.1-few.json' # examples used for few-shot learning
    test_path = './data/squad-dev.json'
    result_save_path = './result/{}/prediction-few.csv'.format(model)
    dataset_name = 'hotpotqa'
    predict_few(model, few_shot_path, test_path, result_save_path, dataset_name)
    
  4. Use GPTs

    Run the code in generate_gpt.py. The script queries the OpenAI API, so make sure your API key is set (typically via the OPENAI_API_KEY environment variable).

    data_path = './test.xlsx'    # your data for QG
    model = 'gpt-3.5-turbo'
    save_dir = './result/'
    few_shot_path = './data/hotpotqa-few.json'
    question_generation(
        data_path=data_path,
        model=model,
        save_dir=save_dir,
        few_shot_path=few_shot_path
    )
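
For a quick check outside apply_lmqg.py, the lmqg checkpoints used in option 1 can also be called directly through the lmqg library's TransformersQG class. A minimal sketch (the context and answer below are made-up examples):

    # Query an lmqg checkpoint directly (sketch, independent of apply_lmqg.py)
    from lmqg import TransformersQG

    model = TransformersQG(model='lmqg/flan-t5-base-squad-qg')
    # generate one question per (context, answer) pair
    questions = model.generate_q(
        list_context=['William Turner was an English painter who specialised in watercolour landscapes.'],
        list_answer=['William Turner']
    )
    print(questions)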
    


Download files

Download the file for your platform.

Source Distribution

QGEval_qg-1.0.5.tar.gz (548.2 kB)

Uploaded Source

Built Distribution

QGEval_qg-1.0.5-py3-none-any.whl (569.4 kB)

Uploaded Python 3

File details

Details for the file QGEval_qg-1.0.5.tar.gz.

File metadata

  • Download URL: QGEval_qg-1.0.5.tar.gz
  • Upload date:
  • Size: 548.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.5.tar.gz:

  • SHA256: 1645d4359dcef1101ad43cb30358f546e84004c0c8837baa0ed530f6d3ec5de8
  • MD5: 2e32d5059606e12dc8f422b58cc2ffaa
  • BLAKE2b-256: c95748a30a3096953b8a37b7fc646b033c236f02841d9ed9e4163573456520e8
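
If you want to verify a download before installing, Python's standard hashlib module can reproduce the SHA256 digest listed above. A minimal sketch, assuming the file was downloaded to the current directory:

    # Verify the SHA256 digest of the downloaded source distribution
    import hashlib

    expected = '1645d4359dcef1101ad43cb30358f546e84004c0c8837baa0ed530f6d3ec5de8'
    with open('QGEval_qg-1.0.5.tar.gz', 'rb') as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    print('OK' if digest == expected else 'hash mismatch!')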


File details

Details for the file QGEval_qg-1.0.5-py3-none-any.whl.

File metadata

  • Download URL: QGEval_qg-1.0.5-py3-none-any.whl
  • Upload date:
  • Size: 569.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.5-py3-none-any.whl:

  • SHA256: fba9f12d1f430dfa768407b372d13f5ad0c188480cebca8d4ad5eb9df62a2a4a
  • MD5: 77819602cb8440ffb13da22981d383b6
  • BLAKE2b-256: d8e21434e979e4298b4554df00f427709d7bb22d09e51e1c52c40b66f8bab62e

