
A tool for question generation based on LMs

Project description

How to use

Our code provides question generation (QG) capabilities. Please follow these steps to train your QG models.

Environment

Run pip install -r requirements.txt to install the required packages.
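If you want to verify that key dependencies are importable before training, a quick check like the following can help. The package list here is an assumption for illustration; consult requirements.txt for the actual dependencies.

```python
import importlib.util

# Hypothetical dependency list -- check requirements.txt for the real one.
required = ["torch", "transformers", "datasets"]

# find_spec returns None for top-level packages that are not installed.
missing = [pkg for pkg in required if importlib.util.find_spec(pkg) is None]
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed.")
```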

Train QG Models

  1. Prepare data

    You can use the HotpotQA dataset (./data/hotpotqa) and the SQuAD 1.1 dataset (./data/squad1.1) we provide, or use your own dataset that provides passages, reference questions, and answers.

    If you use the datasets we provide, run python process.py to process the data so it is ready for training.

    # For example: train T5 based on HotpotQA
    save_dir = 'directory to save processed data'
    hot = T5Hotpot()
    hot.deal_hotpot(save_dir)
    
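If you bring your own dataset, each example needs a passage, a reference question, and an answer. The record layout below is an assumption for illustration only; match the field names to what process.py actually expects.

```python
import json

# Hypothetical record layout -- the field names are illustrative,
# not the exact schema process.py consumes.
example = {
    "passage": "Barack Obama was the 44th President of the United States.",
    "question": "Who was the 44th President of the United States?",  # reference
    "answer": "Barack Obama",
}

# Datasets are commonly stored as a JSON list of such records.
print(json.dumps([example], indent=2))
```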
  2. Train your own QG models

    We provide code to train T5-, FLAN-T5-, and BART-based QG models. Update the training arguments in ./args and run the code file for the specific model you want to train.

    # model name: code file name
    BART-base/large_finetune: bart.py
    T5-base/large_finetune: T5.py
    Flan-T5-base/large_finetune: Flan-T5.py
    Flan-T5-xl/xxl_LoRA: Flan-T5-lora.py
    

    For example, to train a T5-base_finetune QG model on HotpotQA:

    • Update the arg_file ./args/t5_base_hotpot.json

      {
          "model_name_or_path": "google-t5/t5-base",
          "tokenizer_name_or_path": "google-t5/t5-base",
          "max_len": 512,
          "target_max_len": 163,
          "output_dir": "./model/t5-base/hotpotqa",
          "overwrite_output_dir": true,
          "per_device_train_batch_size": 16,
          "per_device_eval_batch_size": 16,
          "gradient_accumulation_steps": 8,
          "learning_rate": 0.0001,
          "warmup_steps": 500,
          "weight_decay": 0.01,
          "num_train_epochs": 10,
          "seed": 42,
          "do_train": true,
          "do_eval": true,
          "logging_steps": 100,
          "logging_dir": "./model/t5-base/logs",
          "train_file_path": "./data/data_t5/hotpotqa/pt/train.pt",
          "valid_file_path": "./data/data_t5/hotpotqa/pt/dev.pt",
          "remove_unused_columns": false,
          "prediction_loss_only": true,
          "load_best_model_at_end": true,
          "evaluation_strategy": "epoch",
          "save_strategy": "epoch"
      }
      
    • Run the code in T5.py to train the QG model

      # train model
      arg_path = './args/t5_base_hotpot.json'
      data_dir = 'directory of your data/'   # processed data; keep the trailing slash
      train_path = data_dir + 'train.json'
      dev_path = data_dir + 'dev.json'
      test_path = data_dir + 'test.json'   # optional
      train_qg(arg_path, train_path, dev_path, test_path)
      
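Before launching a run, it can be worth sanity-checking the arg file. For instance, the effective batch size is per_device_train_batch_size × gradient_accumulation_steps (times the number of devices, on multi-GPU setups). A minimal check, assuming the JSON layout shown above:

```python
import json

# A subset of the arg file shown above, inlined here for illustration.
args = json.loads("""
{
    "per_device_train_batch_size": 16,
    "gradient_accumulation_steps": 8,
    "num_train_epochs": 10
}
""")

# Effective batch size = per-device batch * gradient accumulation steps
# (times the number of devices, if training on multiple GPUs).
effective_batch = (args["per_device_train_batch_size"]
                   * args["gradient_accumulation_steps"])
print(effective_batch)  # 128
```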

Generate questions using QG models

We provide different types of QG models for question generation.

  1. Use fine-tuned models in lmqg

    Run python apply_lmqg.py to get questions generated by models from lmqg.

    model_name = 'lmqg/flan-t5-base-squad-qg'
    data_json_path = './data/squad-test.json'
    save_path = './result/squad-test.csv'
    questions = apply_lmqg(
        model_name=model_name,
        data_json_path=data_json_path,
        save_path=save_path,
        batch_size=8
    )
    
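Generation over a test file is typically done in batches of batch_size. A generic chunking helper like this (a sketch of the idea, not the actual apply_lmqg implementation) illustrates how the inputs are split:

```python
def batched(items, batch_size):
    """Yield successive batches of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# 19 passages with batch_size=8 -> batches of 8, 8, and 3.
contexts = [f"passage {i}" for i in range(19)]
batches = list(batched(contexts, 8))
print([len(b) for b in batches])  # [8, 8, 3]
```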
  2. Use the fine-tuned (including LoRA) models we provide for QG

    Run the code file for each type of model. For example, to use the T5-base_finetune model to generate questions on the HotpotQA test set, run the code in T5.py.

    # test
    arg_path = './args/t5_base_hotpot.json'
    test_pt_path = './data/data_t5/hotpotqa/pt/test.pt'    # tokenized test data
    model_dir = './model/t5-base/hotpotqa/'
    tokenizer_dir = model_dir + 'tokenizer/'
    result_save_path = './result/t5-base-finetune/hotpotqa/prediction.csv'
    decoded_texts = predict(model_dir, tokenizer_dir, test_pt_path, result_save_path)
    
  3. Use few-shot mode

    Run the code file for each type of model. For example, to use a Flan-T5 model in few-shot mode to generate questions, run the code in Flan-T5.py.

    # few-shot prediction
    model = 'google/flan-t5-large'
    few_shot_path = './data/squad1.1-few.json' # examples used for few-shot learning
    test_path = './data/squad-dev.json'
    result_save_path = './result/{}/prediction-few.csv'.format(model)
    dataset_name = 'hotpotqa'
    predict_few(model, few_shot_path, test_path, result_save_path, dataset_name)
    
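Few-shot prediction works by prepending a handful of worked examples to each test instance before asking the model to generate a question. A sketch of how such a prompt could be assembled (the template and field names are assumptions, not the exact format used in Flan-T5.py):

```python
# Hypothetical few-shot examples; in practice these would be loaded from
# a file such as the few_shot_path JSON above.
few_shot = [
    {"passage": "The Eiffel Tower is in Paris.",
     "answer": "Paris",
     "question": "Where is the Eiffel Tower?"},
]

def build_prompt(examples, passage, answer):
    """Concatenate worked examples, then the new instance, into one prompt."""
    parts = []
    for ex in examples:
        parts.append(f"Passage: {ex['passage']}\nAnswer: {ex['answer']}\n"
                     f"Question: {ex['question']}")
    # The new instance ends with "Question:" so the model completes it.
    parts.append(f"Passage: {passage}\nAnswer: {answer}\nQuestion:")
    return "\n\n".join(parts)

prompt = build_prompt(few_shot, "Mount Fuji is the highest peak in Japan.",
                      "Mount Fuji")
print(prompt)
```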
  4. Use GPTs

    Run the code in generate_gpt.py:

    data_path = './test.xlsx'    # your data for QG
    model = 'gpt-3.5-turbo'
    save_dir = './result/'
    few_shot_path = './data/hotpotqa-few.json'
    question_generation(
        data_path=data_path,
        model=model,
        save_dir=save_dir,
        few_shot_path=few_shot_path
    )
    
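Under the hood, a call to a chat model takes a list of role-tagged messages. A sketch of how the QG request could be framed (the instruction wording is an assumption, and the actual prompt in generate_gpt.py may differ):

```python
# Build the messages payload for a chat-completion request.
# The instruction wording here is illustrative, not the exact prompt used.
passage = "The Great Wall of China is over 13,000 miles long."
answer = "13,000 miles"

messages = [
    {"role": "system",
     "content": "You generate one question whose answer is the given span."},
    {"role": "user",
     "content": f"Passage: {passage}\nAnswer: {answer}\nQuestion:"},
]

# With the openai client, this payload would be sent as, e.g.:
# client.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
print(messages[1]["content"])
```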

Download files

Download the file for your platform.

Source Distribution

QGEval_qg-1.0.6.tar.gz (18.1 kB)

Uploaded Source

Built Distribution

QGEval_qg-1.0.6-py3-none-any.whl (28.5 kB)

Uploaded Python 3

File details

Details for the file QGEval_qg-1.0.6.tar.gz.

File metadata

  • Download URL: QGEval_qg-1.0.6.tar.gz
  • Size: 18.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.6.tar.gz:

  • SHA256: 8baa9e41bed4df6852277d6fbcdc24ba8ed271b0e0772ef0766e214edfc1c93b
  • MD5: 7c0deeeb1b9533abe0e944ba01cf94db
  • BLAKE2b-256: b6aaf052241c66b940f8913f58d007f76aa97f147b25184c1593df2cb55d46f1


File details

Details for the file QGEval_qg-1.0.6-py3-none-any.whl.

File metadata

  • Download URL: QGEval_qg-1.0.6-py3-none-any.whl
  • Size: 28.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.6-py3-none-any.whl:

  • SHA256: ca12a310ebb0b35511e80aed7ad4d282ed94f94efd062267f268103a9ec20cb8
  • MD5: 4b949413472d6ab86e04bb5b5cc8534a
  • BLAKE2b-256: 1bd240bc51abf44bd5771c4af439c40c99dcd887d36ba44870659da7d6646bf4

