A tool for question generation based on LMs
Project description
How to use
Our code provides question generation (QG) capabilities. Please follow these steps to train your QG models.
Environment
Run `pip install -r requirements.txt` to install the required packages.
Train QG Models
- Prepare data

You can use the HotpotQA dataset (`./data/hotpotqa`) or the SQuAD1.1 dataset (`./data/squad1.1`) we provide, or use your own dataset that provides passages, references, and answers.
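If you bring your own dataset, each record needs a passage, a reference question, and an answer. A minimal sketch of such a file follows; the field names here are illustrative assumptions, not the tool's required schema:

```python
import json

# Hypothetical records: field names ("passage", "reference", "answer")
# are assumptions for illustration only.
records = [
    {
        "passage": "Marie Curie won two Nobel Prizes, in Physics and in Chemistry.",
        "reference": "How many Nobel Prizes did Marie Curie win?",
        "answer": "two",
    }
]

with open("my_dataset.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)
```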
If you use the datasets we provide, run `python process.py` to process the data for training.

```python
# For example: process HotpotQA for T5 training
save_dir = 'directory to save processed data'
hot = T5Hotpot()
hot.deal_hotpot(save_dir)
```
- Train your own QG models

We provide code to train T5/FLAN-T5/BART-based QG models. You may update the training arguments in `./args`, then run the code file for the specific model to train.

```
# model name: code file name
BART-base/large_finetune: bart.py
T5-base/large_finetune: T5.py
Flan-T5-base/large_finetune: Flan-T5.py
Flan-T5-xl/xxl_LoRA: Flan-T5-lora.py
```
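The model-to-script mapping above can be captured in a small lookup table, e.g. (a convenience sketch, not part of the released code):

```python
# Which training script handles which model (from the table above)
MODEL_TO_SCRIPT = {
    "BART-base/large_finetune": "bart.py",
    "T5-base/large_finetune": "T5.py",
    "Flan-T5-base/large_finetune": "Flan-T5.py",
    "Flan-T5-xl/xxl_LoRA": "Flan-T5-lora.py",
}

def script_for(model_name: str) -> str:
    """Return the training script file for a given model name."""
    return MODEL_TO_SCRIPT[model_name]
```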
For example, to train a T5-base_finetune QG model on HotpotQA:
- Update the arg file `./args/t5_base_hotpot.json`

```json
{
  "model_name_or_path": "google-t5/t5-base",
  "tokenizer_name_or_path": "google-t5/t5-base",
  "max_len": 512,
  "target_max_len": 163,
  "output_dir": "./model/t5-base/hotpotqa",
  "overwrite_output_dir": true,
  "per_device_train_batch_size": 16,
  "per_device_eval_batch_size": 16,
  "gradient_accumulation_steps": 8,
  "learning_rate": 0.0001,
  "warmup_steps": 500,
  "weight_decay": 0.01,
  "num_train_epochs": 10,
  "seed": 42,
  "do_train": true,
  "do_eval": true,
  "logging_steps": 100,
  "logging_dir": "./model/t5-base/logs",
  "train_file_path": "./data/data_t5/hotpotqa/pt/train.pt",
  "valid_file_path": "./data/data_t5/hotpotqa/pt/dev.pt",
  "remove_unused_columns": false,
  "prediction_loss_only": true,
  "load_best_model_at_end": true,
  "evaluation_strategy": "epoch",
  "save_strategy": "epoch"
}
```
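Before launching training, you can sanity-check the arg file with the standard `json` module. The sketch below round-trips a subset of the keys above the way a training script would read them:

```python
import json

# A minimal arg dict mirroring ./args/t5_base_hotpot.json (subset of keys)
args = {
    "model_name_or_path": "google-t5/t5-base",
    "max_len": 512,
    "learning_rate": 0.0001,
    "num_train_epochs": 10,
}

# Write, then re-read the way the training code would
with open("t5_args_check.json", "w") as f:
    json.dump(args, f, indent=2)

with open("t5_args_check.json") as f:
    loaded = json.load(f)

assert loaded == args  # JSON round-trip preserves these values
```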
- Run the code in `T5.py` to train the QG model

```python
# train model
arg_path = './args/t5_base_hotpot.json'
data_dir = 'directory of your data'  # processed data
train_path = data_dir + 'train.json'
dev_path = data_dir + 'dev.json'
test_path = data_dir + 'test.json'  # optional
train_qg(arg_path, train_path, dev_path, test_path)
```
- Generate questions using QG models
We provide different types of QG models for question generation.
- Use fine-tuned models in `lmqg`

Run `python apply_lmqg.py` to get the questions generated by models in `lmqg`.

```python
model_name = 'lmqg/flan-t5-base-squad-qg'
data_json_path = './data/squad-test.json'
save_path = './result/squad-test.csv'
questions = apply_lmqg(
    model_name=model_name,
    data_json_path=data_json_path,
    save_path=save_path,
    batch_size=8
)
```
- Use the fine-tuned (including LoRA) models we provide for QG

Run the specific code file for each type of model. For example, to use the T5-base_finetune model to generate questions on the HotpotQA test set, run the code in `T5.py`.

```python
# test
arg_path = './args/t5_base_hotpot.json'
test_pt_path = './data/data_t5/hotpotqa/pt/test.pt'  # tokenized test data
model_dir = './model/t5-base/hotpotqa/'
tokenizer_dir = model_dir + 'tokenizer/'
result_save_path = './result/t5-base-finetune/hotpotqa/prediction.csv'
decoded_texts = predict(model_dir, tokenizer_dir, test_pt_path, result_save_path)
```
- Use few-shot mode

Run the specific code file for each type of model. For example, to use the Flan-T5-xl_fewshot model to generate questions on the HotpotQA test set, run the code in `Flan-T5.py`.

```python
# few-shot prediction
model = 'google/flan-t5-large'
few_shot_path = './data/squad1.1-few.json'  # examples used for few-shot learning
test_path = './data/squad-dev.json'
result_save_path = './result/{}/prediction-few.csv'.format(model)
dataset_name = 'hotpotqa'
predict_few(model, few_shot_path, test_path, result_save_path, dataset_name)
```
- Use GPTs

Run the code in `generate_gpt.py`.

```python
data_path = './test.xlsx'  # your data for QG
model = 'gpt-3.5-turbo'
save_dir = './result/'
few_shot_path = './data/hotpotqa-few.json'
question_generation(
    data_path=data_path,
    model=model,
    save_dir=save_dir,
    few_shot_path=few_shot_path
)
```
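Few-shot GPT-based QG presumably assembles the few-shot examples and the target passage into a single prompt. A generic sketch of such a prompt builder (illustrative only; this is not the actual prompt used by `generate_gpt.py`, and the field names are assumptions):

```python
# A generic few-shot QG prompt builder (illustrative; not the tool's actual prompt)
def build_prompt(examples, passage, answer):
    """Concatenate few-shot (passage, answer, question) triples,
    then append the target passage/answer for the model to complete."""
    lines = []
    for ex in examples:
        lines.append(f"Passage: {ex['passage']}")
        lines.append(f"Answer: {ex['answer']}")
        lines.append(f"Question: {ex['question']}")
    lines.append(f"Passage: {passage}")
    lines.append(f"Answer: {answer}")
    lines.append("Question:")
    return "\n".join(lines)

examples = [
    {
        "passage": "The Eiffel Tower is located in Paris.",
        "answer": "Paris",
        "question": "Where is the Eiffel Tower located?",
    }
]
prompt = build_prompt(examples, "Mount Everest lies in the Himalayas.", "the Himalayas")
```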
Hashes for QGEval_qg-1.0.5-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | fba9f12d1f430dfa768407b372d13f5ad0c188480cebca8d4ad5eb9df62a2a4a |
| MD5 | 77819602cb8440ffb13da22981d383b6 |
| BLAKE2b-256 | d8e21434e979e4298b4554df00f427709d7bb22d09e51e1c52c40b66f8bab62e |