Run Whisper fine-tuning with ease.
🗣️ Whisper Fine-Tuning (WFT)
WFT is a 🐍 Python library designed to streamline the fine-tuning process of OpenAI's Whisper models on custom datasets. It simplifies 📦 dataset preparation, 🛠️ model fine-tuning, and 💾 result saving.
✨ Features
- 🤗 Hugging Face Integration: Set your organization (or user) name, and everything syncs automatically to the 🤗 Hugging Face hub.
- 🚀 Easy Dataset Preparation and Preprocessing: Quickly prepare and preprocess datasets for 🛠️ fine-tuning.
- 🧠 Fine-Tuning Using LoRA (Low-Rank Adaptation): Fine-tune Whisper models with efficient LoRA techniques.
- ⚙️ Flexible Configuration Options: Customize various aspects of the fine-tuning process.
- 📊 Evaluation Metrics: Supports Character Error Rate (CER) or Word Error Rate (WER) for model evaluation.
- 📈 TensorBoard Logging: Track your training progress in real time with TensorBoard.
- 🤖 Automatic Model Merging and Saving: Automatically merge fine-tuned weights and save the final model.
- 🔄 Resume Training: Resume training seamlessly from interrupted runs.
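As background on the evaluation metrics above: CER is the edit (Levenshtein) distance between the predicted and reference transcriptions divided by the reference length, and WER is the same computation over word tokens. A minimal, dependency-free sketch of both (not WFT's own implementation, which in practice would use a metrics library such as evaluate or jiwer):

```python
def edit_distance(ref, hyp):
    """Classic dynamic-programming Levenshtein distance over any sequences."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            # prev holds dp[i-1][j-1]; dp[j] still holds dp[i-1][j] here
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]

def cer(reference: str, hypothesis: str) -> float:
    """Character Error Rate: character edits over reference length."""
    return edit_distance(reference, hypothesis) / len(reference)

def wer(reference: str, hypothesis: str) -> float:
    """Word Error Rate: the same computation over word tokens."""
    return edit_distance(reference.split(), hypothesis.split()) / len(reference.split())

print(cer("hello world", "hello word"))  # ≈ 0.0909 (1 edit over 11 reference characters)
```

CER is usually the better choice for languages without whitespace word boundaries (such as the zh-TW examples in this README), while WER is conventional for English.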
🛠️ Installation
Install WFT using 🐍 pip:
pip install wft
🚀 Quick Start
Fine-tune a Whisper model on a custom dataset with just a few steps:
- 🧩 Select a Baseline Model: Choose a pre-trained Whisper model.
- 🎵 Select a Dataset: Use a dataset that includes 🎧 audio and ✍️ transcription columns.
- 🏋️ Start Training: Use default training arguments to quickly fine-tune the model.
Here's an example:
from wft import WhisperFineTuner
id = "whisper-large-v3-turbo-zh-TW-test-1"
ft = (
WhisperFineTuner(id)
.set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
.prepare_dataset(
"mozilla-foundation/common_voice_16_1",
src_subset="zh-TW",
src_audio_column="audio",
src_transcription_column="sentence",
)
.train() # Use default training arguments
)
That's it! 🎉 You can now fine-tune your Whisper model easily.
To enable 🤗 Hugging Face integration and push your training logs and model to Hugging Face, set the org parameter when initializing WhisperFineTuner:
id = "whisper-large-v3-turbo-zh-TW-test-1"
org = "JacobLinCool"  # Organization (or username) on Hugging Face to push the model to
ft = (
WhisperFineTuner(id, org)
.set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
.prepare_dataset(
"mozilla-foundation/common_voice_16_1",
src_subset="zh-TW",
src_audio_column="audio",
src_transcription_column="sentence",
)
.train() # Use default training arguments
.merge_and_push() # Merge the model and push it to Hugging Face
)
This will automatically push your training logs 📊 and the fine-tuned model 🤗 to your Hugging Face account under the specified organization.
📚 Usage Guide
1️⃣ Set Baseline Model and Prepare Dataset
You can use a local dataset or a dataset from 🤗 Hugging Face:
ft = (
WhisperFineTuner(id)
.set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
.prepare_dataset(
"mozilla-foundation/common_voice_16_1",
src_subset="zh-TW",
src_audio_column="audio",
src_transcription_column="sentence",
)
)
To upload the preprocessed dataset to Hugging Face:
ft.push_dataset("username/dataset_name")
You can also load an already-preprocessed dataset (preparing it from the source dataset if needed):
ft = (
WhisperFineTuner(id)
.set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
.prepare_dataset(
"username/preprocessed_dataset",
"mozilla-foundation/common_voice_16_1",
src_subset="zh-TW",
src_audio_column="audio",
src_transcription_column="sentence",
)
)
2️⃣ Configure Fine-Tuning
Set the evaluation metric and 🧠 LoRA configuration:
ft.set_metric("cer") # Use CER (Character Error Rate) for evaluation
# Custom LoRA configuration to fine-tune different parts of the model
from peft import LoraConfig
custom_lora_config = LoraConfig(
r=32,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
)
ft.set_lora_config(custom_lora_config)
You can also set custom 🛠️ training arguments:
from transformers import Seq2SeqTrainingArguments
custom_training_args = Seq2SeqTrainingArguments(
output_dir=ft.dir,
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
learning_rate=1e-4,
num_train_epochs=3,
)
ft.set_training_args(custom_training_args)
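For orientation, the arguments above combine into an effective batch size of per_device_train_batch_size × gradient_accumulation_steps (× number of devices). A quick sanity check of that arithmetic, using a hypothetical single-GPU setup and dataset size for illustration:

```python
import math

per_device_train_batch_size = 8
gradient_accumulation_steps = 2
n_devices = 1                      # single-GPU assumption for illustration

# Gradients from 2 small batches of 8 are accumulated before each optimizer step.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * n_devices
print(effective_batch)             # 16 examples per optimizer step

# Hypothetical dataset size, just to show the steps-per-epoch arithmetic:
n_examples = 10_000
steps_per_epoch = math.ceil(n_examples / effective_batch)
print(steps_per_epoch)             # 625
```

Gradient accumulation is how you reach a large effective batch when a batch of 16 audio features would not fit in GPU memory at once.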
3️⃣ Train the Model
To begin 🏋️ fine-tuning:
ft.train()
4️⃣ Save or Push the Fine-Tuned Model
Merge 🧠 LoRA weights with the baseline model and save it:
ft.merge_and_save(f"{ft.dir}/merged_model")
# Or push to Hugging Face
ft.merge_and_push("username/merged_model")
🔬 Advanced Usage
🧠 Custom LoRA Configuration
Adjust the LoRA configuration to fine-tune different model parts:
custom_lora_config = LoraConfig(
r=32,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
)
ft.set_lora_config(custom_lora_config)
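To get a feel for why LoRA is cheap, you can count the parameters a config like the one above adds. For a weight matrix of shape (d_out, d_in), LoRA trains two low-rank factors A (r × d_in) and B (d_out × r), i.e. r·(d_in + d_out) parameters, instead of the full matrix. A back-of-the-envelope sketch (the 1280 hidden size is an assumption for illustration, matching whisper-large's model dimension; WFT's defaults may adapt different modules):

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    # A: r x d_in, B: d_out x r  ->  r * (d_in + d_out) trainable parameters
    return r * (d_in + d_out)

d_model = 1280          # assumed hidden size for illustration
r = 32                  # rank from the LoraConfig above

per_matrix = lora_params(d_model, d_model, r)  # one q_proj or v_proj matrix
print(per_matrix)                              # 81920 trainable parameters

full = d_model * d_model                       # the frozen full matrix
print(per_matrix / full)                       # 0.05 -> LoRA trains ~5% per adapted matrix
```

Raising r increases capacity (and parameter count) linearly; lora_alpha rescales the update as alpha / r, so alpha=16 with r=32 applies the learned update at half strength.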
⚙️ Custom Training Arguments
Specify custom 🛠️ training settings:
from transformers import Seq2SeqTrainingArguments
custom_training_args = Seq2SeqTrainingArguments(
output_dir=ft.dir,
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
learning_rate=1e-4,
num_train_epochs=3,
)
ft.set_training_args(custom_training_args)
🔄 Run Custom Actions After Steps Using .then()
Add actions to be executed after each step:
ft = (
WhisperFineTuner(id)
.set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
.then(lambda ft: print(f"{ft.baseline_model=}"))
.prepare_dataset(
"mozilla-foundation/common_voice_16_1",
src_subset="zh-TW",
src_audio_column="audio",
src_transcription_column="sentence",
)
.then(lambda ft: print(f"{ft.dataset=}"))
.set_metric("cer")
.then(lambda ft: setattr(ft.training_args, "num_train_epochs", 5))
.train()
)
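The chainable .then() style above is easy to replicate in your own helpers: each method performs its side effect and then returns self. A minimal standalone sketch of the pattern (Chainable is a hypothetical class for illustration, not part of WFT):

```python
class Chainable:
    """Tiny illustration of the fluent .then() pattern used by WhisperFineTuner."""

    def __init__(self, id: str):
        self.id = id
        self.log = []

    def then(self, fn):
        # Run an arbitrary callback with the instance, then keep chaining.
        fn(self)
        return self

    def step(self, name: str):
        # Stand-in for a real pipeline step like prepare_dataset or train.
        self.log.append(name)
        return self

c = (
    Chainable("demo")
    .step("prepare")
    .then(lambda c: print(f"after prepare: {c.log}"))
    .step("train")
)
print(c.log)  # ['prepare', 'train']
```

Because every method returns the same instance, .then() can inspect or mutate intermediate state (as the setattr example above does with training_args) without breaking the chain.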
🔄 Resume Training From a Checkpoint
If training is interrupted, you can resume:
ft = (
WhisperFineTuner(id)
.set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
.prepare_dataset(
"mozilla-foundation/common_voice_16_1",
src_subset="zh-TW",
src_audio_column="audio",
src_transcription_column="sentence",
)
.set_metric("cer")
.train(resume=True)
)
ℹ️ Note: If no checkpoint is found, training will start from scratch instead of failing.
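Under the hood, resuming typically means locating the newest checkpoint-&lt;step&gt; directory that transformers' Trainer writes into the output folder. A hedged sketch of that lookup (the directory layout is an assumption about Trainer's convention; WFT may implement resumption differently):

```python
import os
import re

def latest_checkpoint(output_dir: str):
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    if not os.path.isdir(output_dir):
        return None
    best_path, best_step = None, -1
    for name in os.listdir(output_dir):
        m = re.fullmatch(r"checkpoint-(\d+)", name)
        if m and os.path.isdir(os.path.join(output_dir, name)):
            step = int(m.group(1))
            if step > best_step:
                best_path, best_step = os.path.join(output_dir, name), step
    return best_path  # None means no checkpoint: start from scratch
```

Returning None rather than raising is what makes the no-checkpoint case above fall through to a fresh run.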
🤝 Contributing
We welcome contributions! 🎉 Feel free to submit a pull request.
📄 License
This project is licensed under the MIT License.
Why are there so many emojis in this README?
Because ChatGPT told me to. 🤖
File details
Details for the file wft-0.0.8.tar.gz.
File metadata
- Download URL: wft-0.0.8.tar.gz
- Upload date:
- Size: 11.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5e111ce1776a2e100cb6512b34209b43016124007994303a6de4533bdd760ad9
MD5 | ccd9f167ff7adb1c6c552f12e0a187b5
BLAKE2b-256 | e5a65419ad41069241e12cc420b0352acc14165a5e5ea5365d70549b699b6f47
File details
Details for the file wft-0.0.8-py3-none-any.whl.
File metadata
- Download URL: wft-0.0.8-py3-none-any.whl
- Upload date:
- Size: 11.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 18a655e79466797622dbbd312e6d5711a13d5098f25d12e9922c69c078438b51
MD5 | 3f38dd7d933865d195a77010b9d37f42
BLAKE2b-256 | 491870a632033b9b454c9e594e400db7e65920ee83f5f09fab01cf611b8b3612