
A tool for question generation based on language models (LMs)

Project description

How to use

Our code provides question generation (QG) capabilities. Follow these steps to train your own QG models.

Environment

Run pip install -r requirements.txt to install the required packages.
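
After installing, a quick sanity check that the core dependencies are importable. This is a minimal sketch; it assumes the requirements include torch and transformers, which the T5/BART training code relies on:

    # Sanity check: the training and generation code below uses PyTorch and
    # Hugging Face Transformers, so both should import cleanly.
    import torch
    import transformers

    print('torch', torch.__version__)
    print('transformers', transformers.__version__)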

Train QG Models

  1. Prepare data

    You can use the HotpotQA dataset (./data/hotpotqa) or the SQuAD 1.1 dataset (./data/squad1.1) that we provide, or use your own dataset containing passages, references, and answers.

    If you use the datasets we provide, run python process.py to process the data so that it is ready for training.

    # For example: train T5 based on HotpotQA
    save_dir = 'directory to save processed data'
    hot = T5Hotpot()
    hot.deal_hotpot(save_dir)
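
    Processing produces the files that the training arguments and scripts below point to (e.g., the JSON splits and the tokenized .pt files under ./data/data_t5/hotpotqa/pt/).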
    
  2. Train your own QG models

    We provide code to train T5-, FLAN-T5-, and BART-based QG models. Update the training arguments in ./args, then run the code file for the specific model you want to train.

    # model name: code file name
    BART-base/large_finetune: bart.py
    T5-base/large_finetune: T5.py
    Flan-T5-base/large_finetune: Flan-T5.py
    Flan-T5-xl/xxl_LoRA: Flan-T5-lora.py
    

    For example, to train a T5-base_finetune QG model on HotpotQA:

    • Update the arg file ./args/t5_base_hotpot.json (a sketch of how such a file can be loaded follows these steps)

      {
          "model_name_or_path": "google-t5/t5-base",
          "tokenizer_name_or_path": "google-t5/t5-base",
          "max_len": 512,
          "target_max_len": 163,
          "output_dir": "./model/t5-base/hotpotqa",
          "overwrite_output_dir": true,
          "per_device_train_batch_size": 16,
          "per_device_eval_batch_size": 16,
          "gradient_accumulation_steps": 8,
          "learning_rate": 0.0001,
          "warmup_steps": 500,
          "weight_decay": 0.01,
          "num_train_epochs": 10,
          "seed": 42,
          "do_train": true,
          "do_eval": true,
          "logging_steps": 100,
          "logging_dir": "./model/t5-base/logs",
          "train_file_path": "./data/data_t5/hotpotqa/pt/train.pt",
          "valid_file_path": "./data/data_t5/hotpotqa/pt/dev.pt",
          "remove_unused_columns": false,
          "prediction_loss_only": true,
          "load_best_model_at_end": true,
          "evaluation_strategy": "epoch",
          "save_strategy": "epoch"
      }
      
    • Run the code in T5.py to train the QG model

      # train model
      arg_path = './args/t5_base_hotpot.json'
      data_dir = 'directory of your data'    # processed data directory; must end with '/' since paths below are concatenated
      train_path = data_dir + 'train.json'
      dev_path = data_dir + 'dev.json'
      test_path = data_dir + 'test.json'   # optional
      train_qg(arg_path, train_path, dev_path, test_path)
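
    Most keys in the arg file mirror Hugging Face TrainingArguments; note that per_device_train_batch_size 16 with gradient_accumulation_steps 8 gives an effective batch size of 128 per device. As a minimal sketch (not the project's actual loader), such a JSON file could be parsed with transformers' HfArgumentParser; the ModelArguments dataclass below is hypothetical, standing in for the model-specific keys:

      from dataclasses import dataclass
      from transformers import HfArgumentParser, TrainingArguments

      # Hypothetical dataclass for the keys that are not TrainingArguments
      # fields; the real project may handle these differently.
      @dataclass
      class ModelArguments:
          model_name_or_path: str = 'google-t5/t5-base'
          tokenizer_name_or_path: str = 'google-t5/t5-base'
          max_len: int = 512
          target_max_len: int = 163
          train_file_path: str = ''
          valid_file_path: str = ''

      parser = HfArgumentParser((ModelArguments, TrainingArguments))
      model_args, training_args = parser.parse_json_file('./args/t5_base_hotpot.json')
      print(training_args.learning_rate)  # 0.0001

    After training, the fine-tuned model is saved under output_dir (./model/t5-base/hotpotqa in this example); the prediction step below additionally expects a tokenizer/ subdirectory there.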
      

Generate questions using QG models

We provide several types of QG models for generating questions.

  1. Use fine-tuned models in lmqg

    Run python apply_lmqg.py to generate questions with the fine-tuned models available in lmqg.

    model_name = 'lmqg/flan-t5-base-squad-qg'
    data_json_path = './data/squad-test.json'
    save_path = './result/squad-test.csv'
    questions = apply_lmqg(
        model_name=model_name,
        data_json_path=data_json_path,
        save_path=save_path,
        batch_size=8
    )
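
    The generated questions are saved to save_path as a CSV file. A quick way to inspect the output, assuming pandas is installed (the column names depend on how apply_lmqg writes the file):

    import pandas as pd

    # Peek at the first few generated questions
    df = pd.read_csv('./result/squad-test.csv')
    print(df.head())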
    
  2. Use the fine-tuned (including LoRA) models we provide for QG

    Run the specific code file for each type of model. For example, to use the T5-base_finetune model to generate questions on the HotpotQA test set, run the code in T5.py.

    # test
    arg_path = './args/t5_base_hotpot.json'
    test_pt_path = './data/data_t5/hotpotqa/pt/test.pt'    # tokenized test data
    model_dir = './model/t5-base/hotpotqa/'
    tokenizer_dir = model_dir + 'tokenizer/'
    result_save_path = './result/t5-base-finetune/hotpotqa/prediction.csv'
    decoded_texts = predict(model_dir, tokenizer_dir, test_pt_path, result_save_path)
    
  3. Use few-shot mode

    Run the specific code file for each type of model. For example, to use a Flan-T5 model in few-shot mode to generate questions on the HotpotQA test set, run the code in Flan-T5.py.

    # few-shot prediction
    model = 'google/flan-t5-large'
    few_shot_path = './data/squad1.1-few.json' # examples used for few-shot learning
    test_path = './data/squad-dev.json'
    result_save_path = './result/{}/prediction-few.csv'.format(model)
    dataset_name = 'hotpotqa'
    predict_few(model, few_shot_path, test_path, result_save_path, dataset_name)
    
  4. Use GPTs

    Run the code in generate_gpt.py:

    data_path = './test.xlsx'    # your data for QG
    model = 'gpt-3.5-turbo'
    save_dir = './result/'
    few_shot_path = './data/hotpotqa-few.json'
    question_generation(
        data_path=data_path,
        model=model,
        save_dir=save_dir,
        few_shot_path=few_shot_path
    )
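
    Note that generate_gpt.py calls the OpenAI API, so you will likely need an API key configured (for example, via the OPENAI_API_KEY environment variable) before running it.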
    



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

QGEval_qg-1.0.7.tar.gz (17.9 kB)

Uploaded Source

Built Distribution

QGEval_qg-1.0.7-py3-none-any.whl (28.2 kB)

Uploaded Python 3

File details

Details for the file QGEval_qg-1.0.7.tar.gz.

File metadata

  • Download URL: QGEval_qg-1.0.7.tar.gz
  • Upload date:
  • Size: 17.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.7.tar.gz

  • SHA256: de89830d00bd5ec4baed162ad4e182047873aabbdc0f94c6f93df13e5aa3dabf
  • MD5: a9fe0d067efd462eaae5861442a5909d
  • BLAKE2b-256: 90045b84d3635a550c867df88b7f1971206dad059fdc7eb547f4eaba9e34c90b
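
To verify a download against the published digests, a minimal check with Python's standard hashlib (adjust the path to wherever you saved the archive; the same check applies to the wheel below):

    import hashlib

    # Compare the SHA256 of the downloaded archive against the published digest
    path = 'QGEval_qg-1.0.7.tar.gz'
    expected = 'de89830d00bd5ec4baed162ad4e182047873aabbdc0f94c6f93df13e5aa3dabf'

    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).hexdigest()

    print('OK' if digest == expected else 'hash mismatch')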


File details

Details for the file QGEval_qg-1.0.7-py3-none-any.whl.

File metadata

  • Download URL: QGEval_qg-1.0.7-py3-none-any.whl
  • Upload date:
  • Size: 28.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.7-py3-none-any.whl

  • SHA256: 1cee9b98a42483c8797ce55d1b8ffadc6573895565d36ce65f6ec6b0a453c61f
  • MD5: ee2b0fc9f96051e65bc96144b003fa97
  • BLAKE2b-256: 110ebad2dc74c5e42024621346a46be08467819b626380d7a9c946d64ac34f26

