A tool for question generation based on LMs

Project description

How to use

Our code supports question generation (QG). Follow these steps to train your own QG models.

Environment

Run pip install -r requirements.txt to install the required packages.

Train QG Models

  1. Prepare data

    You can use the HotpotQA dataset (./data/hotpotqa) or the SQuAD 1.1 dataset (./data/squad1.1) that we provide, or use your own dataset that supplies passages, reference questions, and answers (a possible layout is sketched after the example below).

    If you use the datasets we provide, run python process.py to preprocess the data for training.

    # For example: preprocess HotpotQA for T5-based training
    save_dir = 'directory to save processed data'
    hot = T5Hotpot()
    hot.deal_hotpot(save_dir)
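
    If you bring your own dataset, each example should pair a passage with a reference question and an answer. The sketch below shows one plausible layout; the field names are assumptions, so check process.py for the exact schema its loaders expect.

    # hypothetical layout of one custom dataset entry (field names are assumptions)
    custom_example = {
        "passage": "Mount Everest is Earth's highest mountain above sea level ...",
        "reference": "What is the highest mountain on Earth?",
        "answer": "Mount Everest",
    }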
    
  2. Train your own QG models

    We provide code to train T5-, FLAN-T5-, and BART-based QG models. Update the training arguments in ./args, then run the code file for the specific model you want to train.

    # model name: code file name
    BART-base/large_finetune: bart.py
    T5-base/large_finetune: T5.py
    Flan-T5-base/large_finetune: Flan-T5.py
    Flan-T5-xl/xxl_LoRA: Flan-T5-lora.py
    

    For example, to train a T5-base_finetune QG model on HotpotQA:

    • Update the arg file ./args/t5_base_hotpot.json

      {
          "model_name_or_path": "google-t5/t5-base",
          "tokenizer_name_or_path": "google-t5/t5-base",
          "max_len": 512,
          "target_max_len": 163,
          "output_dir": "./model/t5-base/hotpotqa",
          "overwrite_output_dir": true,
          "per_device_train_batch_size": 16,
          "per_device_eval_batch_size": 16,
          "gradient_accumulation_steps": 8,
          "learning_rate": 0.0001,
          "warmup_steps": 500,
          "weight_decay": 0.01,
          "num_train_epochs": 10,
          "seed": 42,
          "do_train": true,
          "do_eval": true,
          "logging_steps": 100,
          "logging_dir": "./model/t5-base/logs",
          "train_file_path": "./data/data_t5/hotpotqa/pt/train.pt",
          "valid_file_path": "./data/data_t5/hotpotqa/pt/dev.pt",
          "remove_unused_columns": false,
          "prediction_loss_only": true,
          "load_best_model_at_end": true,
          "evaluation_strategy": "epoch",
          "save_strategy": "epoch"
      }
      
    • Run the code in T5.py to train the QG model

      # train model
      arg_path = './args/t5_base_hotpot.json'
      data_dir = 'directory of your data'    # processed data; path must end with '/' since file paths are concatenated below
      train_path = data_dir + 'train.json'
      dev_path = data_dir + 'dev.json'
      test_path = data_dir + 'test.json'   # optional
      train_qg(arg_path, train_path, dev_path, test_path)
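
      After training, the checkpoint saved under output_dir can be smoke-tested with the plain transformers API. This is a minimal sketch: the prompt layout below is an assumption, so mirror whatever input format process.py produced for training.

      # quick sanity check of a trained checkpoint (prompt format is an assumption)
      from transformers import T5ForConditionalGeneration, T5Tokenizer

      model_dir = './model/t5-base/hotpotqa/'              # output_dir from the arg file
      tokenizer = T5Tokenizer.from_pretrained(model_dir + 'tokenizer/')
      model = T5ForConditionalGeneration.from_pretrained(model_dir)

      text = "answer: Mount Everest context: Mount Everest is Earth's highest mountain ..."
      inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
      outputs = model.generate(**inputs, max_new_tokens=64, num_beams=4)
      print(tokenizer.decode(outputs[0], skip_special_tokens=True))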
      

Generate questions using QG models

We provide different types of QG models for question generation.

  1. Use fine-tuned models in lmqg

    Run python apply_lmqg.py to generate questions with models from lmqg.

    model_name = 'lmqg/flan-t5-base-squad-qg'
    data_json_path = './data/squad-test.json'
    save_path = './result/squad-test.csv'
    questions = apply_lmqg(
        model_name=model_name,
        data_json_path=data_json_path,
        save_path=save_path,
        batch_size=8
    )
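
    If the lmqg package is installed, its TransformersQG interface can also be called directly instead of going through apply_lmqg.py. A minimal sketch based on the lmqg README:

    # direct use of lmqg (see the lmqg documentation for details)
    from lmqg import TransformersQG

    model = TransformersQG(model='lmqg/flan-t5-base-squad-qg')
    context = ["William Turner was an English painter who specialised in watercolour landscapes."]
    answer = ["William Turner"]
    questions = model.generate_q(list_context=context, list_answer=answer)
    print(questions)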
    
  2. Use the fine-tuned (including LoRA) models we provide for QG

    Run the code file for the specific model type. For example, to use the T5-base_finetune model to generate questions on the HotpotQA test set, run the code in T5.py.

    # test
    arg_path = './args/t5_base_hotpot.json'
    test_pt_path = './data/data_t5/hotpotqa/pt/test.pt'    # tokenized test data
    model_dir = './model/t5-base/hotpotqa/'
    tokenizer_dir = model_dir + 'tokenizer/'
    result_save_path = './result/t5-base-finetune/hotpotqa/prediction.csv'
    decoded_texts = predict(model_dir, tokenizer_dir, test_pt_path, result_save_path)
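
    The predictions are written to result_save_path as a CSV, which can be inspected with pandas. A minimal sketch; the exact column names are up to predict(), so check its output header.

    # inspect the generated questions (column names depend on predict())
    import pandas as pd

    df = pd.read_csv('./result/t5-base-finetune/hotpotqa/prediction.csv')
    print(df.head())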
    
  3. Use few-shot mode

    Run the code file for the specific model type. For example, to use the Flan-T5-xl_fewshot model to generate questions on the HotpotQA test set, run the code in Flan-T5.py.

    # few-shot prediction
    model = 'google/flan-t5-large'
    few_shot_path = './data/squad1.1-few.json' # examples used for few-shot learning
    test_path = './data/squad-dev.json'
    result_save_path = './result/{}/prediction-few.csv'.format(model)
    dataset_name = 'hotpotqa'
    predict_few(model, few_shot_path, test_path, result_save_path, dataset_name)
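
    The few_shot_path file supplies the in-context examples used for prompting. Its schema is not shown here, so the layout below is an assumption; check predict_few for the real keys.

    # hypothetical few-shot example entry (field names are assumptions)
    few_shot_example = {
        "passage": "The Amazon rainforest covers much of the Amazon basin of South America.",
        "answer": "the Amazon basin",
        "question": "What does the Amazon rainforest cover much of?",
    }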
    
  4. Use GPTs

    Run the code in generate_gpt.py:

    data_path = './test.xlsx'    # your data for QG
    model = 'gpt-3.5-turbo'
    save_dir = './result/'
    few_shot_path = './data/hotpotqa-few.json'
    question_generation(
        data_path=data_path,
        model=model,
        save_dir=save_dir,
        few_shot_path=few_shot_path
    )
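
    generate_gpt.py calls the OpenAI API, so a key must be available before running it. How the script reads credentials is an assumption; the usual convention is the environment variable below.

    # make OpenAI credentials available (conventional env variable; an assumption here)
    import os
    os.environ['OPENAI_API_KEY'] = 'sk-...'    # or export it in your shell instead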
    


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

QGEval_qg-1.0.11.tar.gz (18.0 kB)

Uploaded Source

Built Distribution

QGEval_qg-1.0.11-py3-none-any.whl (28.3 kB)

Uploaded Python 3

File details

Details for the file QGEval_qg-1.0.11.tar.gz.

File metadata

  • Download URL: QGEval_qg-1.0.11.tar.gz
  • Size: 18.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.11.tar.gz:

  • SHA256: b76604dd7ccb6dc048eefc0a3fb3456bcbab285d9d86169cc95f22fd52dbfeaf
  • MD5: cb81908031a409b2a6c6931229a639e1
  • BLAKE2b-256: ce678717e1d957b213f73e1178366256be9fce8e6438983c6ba8037f9f750d6d

File details

Details for the file QGEval_qg-1.0.11-py3-none-any.whl.

File metadata

  • Download URL: QGEval_qg-1.0.11-py3-none-any.whl
  • Size: 28.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.10

File hashes

Hashes for QGEval_qg-1.0.11-py3-none-any.whl:

  • SHA256: 56950d62a04c0a5355358825fab301b3eb2014114f4466307bae411c24dbe247
  • MD5: c5d744da4f5b59b77b258d583caad0b1
  • BLAKE2b-256: 7d1eedfaee1353df8bcc7629eb59a1fa22516dfb6938042526bdad7f068e38e4
