Genetic Algorithm: Optimize the finetuning of FLAN-T5 models
Project description
gentune-flan
A genetic algorithm for optimising the fine-tuning of FLAN-T5.
Library dependencies
The following libraries must be installed for gentune-flan to work:
pip install evaluate transformers rouge-score pandas numpy torch
Install the gentune-flan library:
pip install gentune-flan
Example (summary generation with the FLAN-T5 Base model)
import evaluate
import torch
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hyphens are not valid in Python module names. The import name is assumed
# to match the distribution file name, gentune_flan.
import gentune_flan

# Load CNN/DailyMail and keep the first 100 test examples.
dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="./hf_cache")
test_data = dataset["test"].select(range(100))

# Build prompt/reference pairs from the first 10 examples, skipping empty ones.
prompts = []
references = []
for ex in list(test_data)[0:10]:
    article = ex.get("article", "")
    summary = ex.get("highlights", "")
    if (isinstance(article, str) and isinstance(summary, str)
            and article.strip() != "" and summary.strip() != ""):
        prompt = f"Summarize the following article:\n\n{article.strip()}\n\nSummary:"
        prompts.append(prompt)
        references.append(summary.strip())

# Candidate values for each generation hyperparameter.
params = {
    "max_length": [200, 300, 400, 500, 600, 700],
    "temperature": [0.1, 0.3, 0.5, 0.7, 0.9, 1],
    "top_k": [10, 20, 30, 50, 70, 90, 100],
    "top_p": [0.2, 0.3, 0.5, 0.7, 0.9],
    "do_sample": True,  # Enables sampling-based decoding
    "num_beams": 1,     # Beam search disabled (greedy/sample)
}

# Run the genetic search and print the selected generation arguments.
args = gentune_flan.optimize_flan_t5_base(
    prompts=prompts,
    references=references,
    params=params,
    error_metric='ROUGE-L',
    number_of_generation=2,
    mutation_rate=0.05,
    population_size=4,
    random_seed=2025,
)
print(args)

# Generate summaries with the selected arguments.
model_name = "google/flan-t5-base"  # or your fine-tuned path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        **args,
    )

decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Score the generated summaries against the references.
rouge = evaluate.load("rouge")
result = rouge.compute(predictions=decoded_preds, references=references, use_stemmer=True)

rouge_value = round(result['rougeL'] * 100, 2)
print(rouge_value)
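The example reports ROUGE-L, which rewards the longest common subsequence (LCS) of tokens shared by a prediction and its reference. The sketch below illustrates that idea in plain Python. It is a simplified approximation, not the implementation used by evaluate or rouge-score: it assumes lowercase whitespace tokenisation and no stemming, and the function names lcs_length and rouge_l_f1 are hypothetical.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction, reference):
    """Simplified ROUGE-L F1 (assumption: whitespace tokens, no stemming)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    lcs = lcs_length(pred_tokens, ref_tokens)
    if lcs == 0:
        return 0.0
    precision = lcs / len(pred_tokens)  # share of predicted tokens in the LCS
    recall = lcs / len(ref_tokens)      # share of reference tokens in the LCS
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    print(rouge_l_f1("the cat", "the cat sat on the mat"))
```

A short prediction that is fully contained in the reference gets perfect precision but low recall, so very small max_length values tend to be penalised by this kind of score.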
Default hyperparameter value ranges
max_length : [50,100,150,200,250,300]
temperature : [0.2,0.4,0.6,0.8]
top_k : [10,20,30,40,50]
top_p : [0.3,0.5,0.7,0.9]
error_metric = 'ROUGE-L' ('ROUGE-1' and 'ROUGE-2' can also be used)
number_of_generation = 3 (increase for more rounds of evolution)
mutation_rate = 0.05
population_size = 10 (increase to evaluate more candidates per generation)
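To show how these settings interact, here is a minimal sketch of a genetic search over the default value ranges. It is not the gentune-flan implementation: the selection and crossover scheme, the names grid_size, toy_fitness and run_ga, and the toy fitness function (a stand-in for a real ROUGE evaluation) are all assumptions made for illustration.

```python
import random

# Default value ranges listed above.
GRID = {
    "max_length": [50, 100, 150, 200, 250, 300],
    "temperature": [0.2, 0.4, 0.6, 0.8],
    "top_k": [10, 20, 30, 40, 50],
    "top_p": [0.3, 0.5, 0.7, 0.9],
}


def grid_size(grid):
    """Number of distinct hyperparameter combinations in the grid."""
    size = 1
    for values in grid.values():
        size *= len(values)
    return size


def toy_fitness(individual):
    """Hypothetical stand-in for a ROUGE score: counts matches to an arbitrary target."""
    target = {"max_length": 200, "temperature": 0.6, "top_k": 30, "top_p": 0.9}
    return sum(1.0 for key in target if individual[key] == target[key])


def run_ga(grid, fitness, generations=3, population_size=10,
           mutation_rate=0.05, seed=2025):
    """Return the fittest individual seen across all generations."""
    if population_size < 2:
        raise ValueError("population_size must be at least 2")
    rng = random.Random(seed)
    keys = list(grid)
    population = [{k: rng.choice(grid[k]) for k in keys}
                  for _ in range(population_size)]
    best = max(population, key=fitness)
    for _ in range(generations):
        # Selection: keep the fitter half of the population as parents.
        ranked = sorted(population, key=fitness, reverse=True)
        parents = ranked[:max(2, population_size // 2)]
        children = []
        while len(children) < population_size:
            mother, father = rng.sample(parents, 2)
            # Uniform crossover: each gene comes from one parent at random.
            child = {k: rng.choice((mother[k], father[k])) for k in keys}
            # Mutation: occasionally resample a gene from its value range.
            for k in keys:
                if rng.random() < mutation_rate:
                    child[k] = rng.choice(grid[k])
            children.append(child)
        population = children
        best = max(population + [best], key=fitness)
    return best


if __name__ == "__main__":
    print(grid_size(GRID))  # 6 * 4 * 5 * 4 combinations
    print(run_ga(GRID, toy_fitness))
```

With the default settings this sketch creates 40 candidates in total (10 initial plus 3 generations of 10), compared with 480 combinations in the full grid, which is why raising number_of_generation or population_size trades run time for broader coverage.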
Download files
File details
Details for the file gentune_flan-0.1.6.tar.gz.
File metadata
- Download URL: gentune_flan-0.1.6.tar.gz
- Upload date:
- Size: 7.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a521b35f096019ec731ef544001fd34acac7dec2b2647472a4e6d7821f92b5a4 |
| MD5 | 96cf478ba8e74220e77d6adbe2c11aff |
| BLAKE2b-256 | 2ab29b7292a5601dfad9c80bce4abf8b08a3f85d468894f824652c9120bdfca0 |
File details
Details for the file gentune_flan-0.1.6-py3-none-any.whl.
File metadata
- Download URL: gentune_flan-0.1.6-py3-none-any.whl
- Upload date:
- Size: 12.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 88526406184eb1e26f41e4168cc204ecbb95ea3220ac9d07e2a45a1b27d0590f |
| MD5 | e21eaa45e93a470dc276dfcac46aa0d7 |
| BLAKE2b-256 | 4cbd0a939875dd039886f937295cd602f44b3ed5e6403f971de8d141331dfb9b |