Genetic Algorithm: Optimize the finetuning of FLAN-T5 models

Project description

gentune-flan

Genetic algorithm used to optimise the finetuning of FLAN-T5

Library dependencies

The following libraries must be installed for gentune-flan to work:

pip install evaluate transformers rouge-score pandas numpy torch
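
As a quick sanity check after installing, the sketch below reports which of the listed dependencies cannot be found in the current environment. The helper name `missing_modules` is hypothetical and not part of gentune-flan; the module names are the import names of the packages above (for example, `rouge-score` is imported as `rouge_score`).

```python
from importlib.util import find_spec

# Import names of the dependencies listed above.
REQUIRED_MODULES = ["evaluate", "transformers", "rouge_score", "pandas", "numpy", "torch"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be located."""
    return [name for name in names if find_spec(name) is None]


print("Missing modules:", missing_modules() or "none")
```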

The gentune-flan library

pip install gentune-flan

Example (summary creation with the FLAN-T5 Base model)

import evaluate
import torch
from datasets import load_dataset  # needs: pip install datasets
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import gentune_flan  # import name assumed from the package name; hyphens are not valid in module names

# Load CNN/DailyMail and keep the first 100 test examples
dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="./hf_cache")
test_data = dataset["test"].select(range(100))

# Build prompt/reference pairs from the first 10 examples, skipping empty entries
prompts = []
references = []
for ex in list(test_data)[0:10]:
    article = ex.get("article", "")
    summary = ex.get("highlights", "")
    if isinstance(article, str) and isinstance(summary, str) and article.strip() != "" and summary.strip() != "":
        prompt = f"Summarize the following article:\n\n{article.strip()}\n\nSummary:"
        prompts.append(prompt)
        references.append(summary.strip())

# Candidate values for each generation hyperparameter
params = {
    "max_length": [200, 300, 400, 500, 600, 700],
    "temperature": [0.1, 0.3, 0.5, 0.7, 0.9, 1],
    "top_k": [10, 20, 30, 50, 70, 90, 100],
    "top_p": [0.2, 0.3, 0.5, 0.7, 0.9],
    "do_sample": True,       # Enables sampling-based decoding
    "num_beams": 1           # Beam search disabled (greedy/sample)
}

# Run the genetic search; the returned arguments are passed to model.generate() below
args = gentune_flan.optimize_flan_t5_base(
    prompts=prompts,
    references=references,
    params=params,
    error_metric='ROUGE-L',
    number_of_generation=2,
    mutation_rate=0.05,
    population_size=4,
    random_seed=2025,
)
print(args)

model_name = "google/flan-t5-base"  # or your fine-tuned path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Tokenize all prompts as one padded batch
inputs = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512
)

# Generate summaries with the selected hyperparameters
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        **args
    )

decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Score the generated summaries against the reference highlights
rouge = evaluate.load("rouge")
result = rouge.compute(predictions=decoded_preds, references=references, use_stemmer=True)
rouge_value = round(result['rougeL'] * 100, 2)
print(rouge_value)
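
The example reports ROUGE-L, which scores a prediction by the longest common subsequence (LCS) of tokens it shares with the reference. To make the metric concrete, here is a simplified sketch for a single sentence pair. It assumes plain whitespace tokenization and no stemming, so its values can differ from those of the `evaluate` implementation used in the example; the function names are illustrative only.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(reference, prediction):
    """Simplified ROUGE-L F-measure (lowercased whitespace tokens, no stemming)."""
    ref = reference.lower().split()
    pred = prediction.lower().split()
    if not ref or not pred:
        return 0.0
    lcs = lcs_length(ref, pred)
    if lcs == 0:
        return 0.0
    precision = lcs / len(pred)
    recall = lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


score = rouge_l_f1("the cat sat on the mat", "the cat is on the mat")
print(round(score * 100, 2))  # 83.33 (LCS of 5 tokens out of 6 on each side)
```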

Default hyperparameter value ranges

max_length : [50, 100, 150, 200, 250, 300]

temperature : [0.2, 0.4, 0.6, 0.8]

top_k : [10, 20, 30, 40, 50]

top_p : [0.3, 0.5, 0.7, 0.9]

error_metric = 'ROUGE-L' ('ROUGE-1' and 'ROUGE-2' can also be used)

number_of_generation = 3 (increase to let the search evolve for longer)

mutation_rate = 0.05

population_size = 10 (increase to evaluate more candidates per generation)
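
To illustrate what number_of_generation, mutation_rate, and population_size control, the following is a minimal sketch of a genetic search over the default value ranges listed above. It is not gentune-flan's actual implementation: the selection scheme (keep the top half), the uniform crossover, and the `toy_fitness` function are all assumptions made for the example. In real use, the fitness would be a ROUGE score computed from generated summaries.

```python
import random

# Default value ranges listed above.
SEARCH_SPACE = {
    "max_length": [50, 100, 150, 200, 250, 300],
    "temperature": [0.2, 0.4, 0.6, 0.8],
    "top_k": [10, 20, 30, 40, 50],
    "top_p": [0.3, 0.5, 0.7, 0.9],
}


def random_individual(rng):
    """Pick one value per hyperparameter at random."""
    return {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}


def crossover(a, b, rng):
    """Uniform crossover: each hyperparameter comes from one of the two parents."""
    return {name: rng.choice((a[name], b[name])) for name in SEARCH_SPACE}


def mutate(individual, rate, rng):
    """Resample each hyperparameter independently with probability `rate`."""
    return {
        name: rng.choice(SEARCH_SPACE[name]) if rng.random() < rate else value
        for name, value in individual.items()
    }


def evolve(fitness, population_size=10, number_of_generation=3,
           mutation_rate=0.05, random_seed=2025):
    """Return the fittest individual found (assumes population_size >= 2)."""
    rng = random.Random(random_seed)
    population = [random_individual(rng) for _ in range(population_size)]
    for _ in range(number_of_generation):
        ranked = sorted(population, key=fitness, reverse=True)
        parents = ranked[: max(2, population_size // 2)]  # survivors carry over unchanged
        children = []
        while len(parents) + len(children) < population_size:
            a, b = rng.sample(parents, 2)
            children.append(mutate(crossover(a, b, rng), mutation_rate, rng))
        population = parents + children
    return max(population, key=fitness)


def toy_fitness(individual):
    """Hypothetical stand-in for a ROUGE score; higher is better."""
    return -abs(individual["temperature"] - 0.6) - abs(individual["top_p"] - 0.9)


best = evolve(toy_fitness)
print(best)
```

The default ranges span 6 × 4 × 5 × 4 = 480 combinations, so a search of this kind scores only a small sample of them. Raising population_size or number_of_generation widens that sample at the cost of more model evaluations.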

