Genetic Algorithm: Optimize the finetuning of FLAN-T5 models

Project description

gentune-flan

A genetic algorithm used to optimise the fine-tuning of FLAN-T5 by searching over generation hyperparameters (max_length, temperature, top_k, top_p).

Library dependencies

The following libraries must be installed for gentune-flan to work:

pip install evaluate transformers rouge-score pandas numpy torch
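To confirm the install worked before running anything heavy, a small check along these lines may help. It is only a sketch: the helper name `missing_packages` is made up for illustration, and the mapping reflects that the `rouge-score` distribution is imported as `rouge_score`.

```python
import importlib.util

# PyPI distribution name -> module name used in `import` statements.
# "rouge-score" installs a module called "rouge_score".
REQUIRED = {
    "evaluate": "evaluate",
    "transformers": "transformers",
    "rouge-score": "rouge_score",
    "pandas": "pandas",
    "numpy": "numpy",
    "torch": "torch",
}


def missing_packages(required=REQUIRED):
    """Return the PyPI names whose import module cannot be found."""
    return [
        dist
        for dist, module in required.items()
        if importlib.util.find_spec(module) is None
    ]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        # Space-separated, so the list can be pasted after `pip install`.
        print("Missing:", " ".join(missing))
    else:
        print("All dependencies found.")
```

The check only looks up module specs, so it does not import torch or transformers and returns quickly.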

The gentune-flan library

pip install gentune-flan

Example (summary creation with the FLAN-T5 Base model)

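import torch
import evaluate
from datasets import load_dataset  # Hugging Face `datasets` package (pip install datasets)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hyphens are not valid in Python module names; the underscore form is assumed
# from the distribution file name (gentune_flan-0.1.7).
import gentune_flan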

dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="./hf_cache")

test_data = dataset["test"].select(range(100))

prompts = []
references = []

# Build prompt/reference pairs from the first 10 test articles, skipping empty entries.
for ex in list(test_data)[0:10]:
    article = ex.get("article", "")
    summary = ex.get("highlights", "")
    if isinstance(article, str) and isinstance(summary, str) and article.strip() != "" and summary.strip() != "":
        prompt = f"Summarize the following article:\n\n{article.strip()}\n\nSummary:"
        prompts.append(prompt)
        references.append(summary.strip())

params = {
    "max_length": [200, 300, 400, 500, 600, 700],
    "temperature": [0.1, 0.3, 0.5, 0.7, 0.9, 1],
    "top_k": [10, 20, 30, 50, 70, 90, 100],
    "top_p": [0.2, 0.3, 0.5, 0.7, 0.9],
    "do_sample": True,  # Enables sampling-based decoding
    "num_beams": 1,     # Beam search disabled (greedy/sample)
}

args = gentune_flan.optimize_flan_t5_base(
    prompts=prompts,
    references=references,
    params=params,
    error_metric='ROUGE-L',
    number_of_generation=2,
    mutation_rate=0.05,
    population_size=4,
    random_seed=2025,
)
print(args)

model_name = "google/flan-t5-base"  # or your fine-tuned path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512)

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        **args,
    )

decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

rouge = evaluate.load("rouge")
result = rouge.compute(predictions=decoded_preds, references=references, use_stemmer=True)

rouge_value = round(result['rougeL'] * 100, 2)
print(rouge_value)
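For readers unfamiliar with the score printed above, ROUGE-L is based on the longest common subsequence (LCS) between the predicted and reference token sequences. The sketch below is a simplified illustration, not the `rouge-score` implementation: it assumes lowercase whitespace tokenisation and no stemming, so its numbers will differ somewhat from `evaluate.load("rouge")`. The function names are chosen for this example only.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction, reference):
    """Simplified ROUGE-L F-measure for one prediction/reference pair."""
    pred = prediction.lower().split()
    ref = reference.lower().split()
    if not pred or not ref:
        return 0.0
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision = lcs / len(pred)  # share of predicted tokens that are in the LCS
    recall = lcs / len(ref)      # share of reference tokens that are in the LCS
    return 2 * precision * recall / (precision + recall)
```

A longer summary can raise recall while lowering precision, which is one reason max_length is a useful hyperparameter to search over.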

Default hyperparameter value ranges

max_length : [50, 100, 150, 200, 250, 300]
temperature : [0.2, 0.4, 0.6, 0.8]
top_k : [10, 20, 30, 40, 50]
top_p : [0.3, 0.5, 0.7, 0.9]

error_metric = 'ROUGE-L' ('ROUGE-1' and 'ROUGE-2' can also be used)
number_of_generation = 3 (increase for more rounds of evolution)
mutation_rate = 0.05
population_size = 10 (increase to evaluate more candidates per generation)
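To show how number_of_generation, mutation_rate and population_size typically interact, here is a minimal genetic search over a discrete grid like the one above. It is a generic sketch and not gentune-flan's actual implementation: the function `genetic_search`, the choice of uniform crossover and elitism, and the toy fitness function are all assumptions made for illustration. In real use the fitness would be something like the ROUGE score of summaries generated with the candidate settings.

```python
import random

DEFAULT_SPACE = {
    "max_length": [50, 100, 150, 200, 250, 300],
    "temperature": [0.2, 0.4, 0.6, 0.8],
    "top_k": [10, 20, 30, 40, 50],
    "top_p": [0.3, 0.5, 0.7, 0.9],
}


def genetic_search(space, fitness, population_size=10, number_of_generation=3,
                   mutation_rate=0.05, random_seed=2025):
    """Return the best candidate found; higher fitness is better."""
    if population_size < 2:
        raise ValueError("population_size must be at least 2")
    rng = random.Random(random_seed)
    keys = list(space)

    def random_individual():
        return {k: rng.choice(space[k]) for k in keys}

    def crossover(a, b):
        # Uniform crossover: each gene is copied from one parent at random.
        return {k: (a[k] if rng.random() < 0.5 else b[k]) for k in keys}

    def mutate(ind):
        # Each gene is resampled from its value list with probability mutation_rate.
        return {k: (rng.choice(space[k]) if rng.random() < mutation_rate else v)
                for k, v in ind.items()}

    population = [random_individual() for _ in range(population_size)]
    for _ in range(number_of_generation):
        ranked = sorted(population, key=fitness, reverse=True)
        parents = ranked[: max(2, population_size // 2)]  # keep the fitter half
        children = [mutate(crossover(*rng.sample(parents, 2)))
                    for _ in range(population_size - 1)]
        population = [ranked[0]] + children  # elitism: best candidate survives
    return max(population, key=fitness)


if __name__ == "__main__":
    # Toy stand-in for a ROUGE-based fitness: prefers temperature 0.6, top_p 0.9.
    def toy_fitness(ind):
        return -(abs(ind["temperature"] - 0.6) + abs(ind["top_p"] - 0.9))

    print(genetic_search(DEFAULT_SPACE, toy_fitness))
```

Each generation scores every candidate, so with a model-based fitness the cost grows roughly with population_size times number_of_generation. Caching scores for repeated candidates would avoid the redundant fitness calls in this sketch.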

Download files

Source Distribution

gentune_flan-0.1.7.tar.gz (7.2 kB view details)

Uploaded Source

Built Distribution

gentune_flan-0.1.7-py3-none-any.whl (12.3 kB view details)

Uploaded Python 3

File details

Details for the file gentune_flan-0.1.7.tar.gz.

File metadata

  • Download URL: gentune_flan-0.1.7.tar.gz
  • Upload date:
  • Size: 7.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.12

File hashes

Hashes for gentune_flan-0.1.7.tar.gz
Algorithm Hash digest
SHA256 41fb48144a692c82a737bfce4d8f2e62c90f079bd6dcbfec9b458b77993594da
MD5 44c334c5ca6d503a1e3a524852cafd16
BLAKE2b-256 20f997fa5368995f276a98133e34a54139f2c3a059dd24e6e2213ee8d61753c8
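A downloaded file can be checked against the SHA256 digest listed above with the standard library. The following is a minimal sketch; the helper names are illustrative and the file path is whatever location the archive was saved to.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """True if the file's SHA256 equals the published hex digest."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, `matches_published_hash("gentune_flan-0.1.7.tar.gz", "41fb48144a692c82a737bfce4d8f2e62c90f079bd6dcbfec9b458b77993594da")` should return True for an intact copy of the source distribution.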

File details

Details for the file gentune_flan-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: gentune_flan-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 12.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.12

File hashes

Hashes for gentune_flan-0.1.7-py3-none-any.whl
Algorithm Hash digest
SHA256 49be308fa1d8099a22d87e2f3261561d6f6e6527edf9136bdfe809c1077450b4
MD5 6ce2ac3995220f053391e11f577ac58e
BLAKE2b-256 402f56a7d086326c5b2a54d0d1c7b23a4502e78fc5831d06b4de59a03599320d
