Run Whisper fine-tuning with ease.
Whisper Fine-Tuning (WFT)
WFT is a Python library for fine-tuning OpenAI's Whisper models on custom datasets. It simplifies the process of preparing datasets, fine-tuning models, and saving the results.
Features
- Easy dataset preparation and preprocessing
- Fine-tuning Whisper models using LoRA (Low-Rank Adaptation)
- Support for custom datasets and Hugging Face datasets
- Flexible configuration options
- Metric calculation (CER or WER)
Installation
pip install wft
Quick Start
Here's a simple example that fine-tunes a Whisper model on Common Voice zh-TW from the Hugging Face Hub:
from wft import WhisperFineTuner

outdir = "output"

ft = (
    WhisperFineTuner(outdir)
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .set_metric("cer")
    .set_lora_config()  # Use default LoRA config
    .train()  # Use default training arguments
    .merge_and_save(f"{outdir}/merged_model")
)
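Once merged, the model in f"{outdir}/merged_model" is a regular Whisper checkpoint, so it can be loaded with standard transformers tooling rather than WFT. A minimal inference sketch, assuming merge_and_save also writes the processor files alongside the weights (the audio path is a placeholder):

from transformers import pipeline

# Load the merged checkpoint like any other Whisper model directory.
asr = pipeline(
    "automatic-speech-recognition",
    model=f"{outdir}/merged_model",
)

# "sample.wav" is a placeholder for your own audio file.
result = asr("sample.wav", generate_kwargs={"language": "zh", "task": "transcribe"})
print(result["text"])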
Usage
1. Set Baseline Model and Prepare Dataset
set_baseline selects the pretrained Whisper checkpoint along with the language and task used for decoding; you can then prepare a dataset from a local source or use a pre-existing Hugging Face dataset:
ft = WhisperFineTuner(outdir).set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
ft.prepare_dataset(
    "mozilla-foundation/common_voice_16_1",
    src_subset="zh-TW",
    src_audio_column="audio",
    src_transcription_column="sentence",
)
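The src_audio_column and src_transcription_column arguments name the columns in the source dataset that hold the audio and the reference text; for Common Voice these are audio and sentence, but other datasets may use different column names.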
2. Configure Fine-tuning
Set the evaluation metric and LoRA configuration:
ft.set_metric("cer")
ft.set_lora_config() # Use default LoRA config
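CER (character error rate) is the character-level edit distance between a prediction and its reference divided by the reference length, the usual choice for languages such as Chinese where word boundaries are ambiguous; WER is the same idea at the word level. To get a feel for the metric outside of WFT, here is a quick standalone check with the evaluate library (which backs its cer metric with jiwer):

import evaluate

cer = evaluate.load("cer")

# Toy example: one wrong character out of four reference characters -> CER = 0.25.
score = cer.compute(
    predictions=["今天天晴"],
    references=["今天天氣"],
)
print(score)  # 0.25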
3. Train the Model
Start the fine-tuning process:
ft.train() # Use default training arguments
4. Save the Fine-tuned Model
Merge the LoRA weights with the base model and save:
ft.merge_and_save(f"{outdir}/merged_model")
# or push to Hugging Face
ft.merge_and_push("username/merged_model")
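merge_and_push uploads to your namespace on the Hugging Face Hub, so it needs an authenticated session. One way to log in beforehand, using the standard huggingface_hub API rather than anything WFT-specific:

from huggingface_hub import login

# Paste a write-scoped token from https://huggingface.co/settings/tokens,
# or run `huggingface-cli login` in a terminal instead.
login()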
Advanced Usage
Custom LoRA Configuration
You can customize the LoRA configuration:
from peft import LoraConfig
custom_lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
ft.set_lora_config(custom_lora_config)
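Here r is the rank of the low-rank adapter matrices, lora_alpha scales the update (the effective scaling is lora_alpha / r, i.e. 0.5 with the values above), target_modules lists the attention projections to adapt (the query and value projections, the common choice from the original LoRA paper), and lora_dropout regularizes the adapter input. As a toy illustration of what these numbers describe (not peft internals or WFT code):

import torch

# LoRA replaces a frozen weight W with W + (lora_alpha / r) * (B @ A).
d, r = 1024, 32                 # toy hidden size and the rank from the config above
W = torch.randn(d, d)           # frozen pretrained projection weight (e.g. q_proj)
A = torch.randn(r, d) * 0.01    # trainable low-rank factor
B = torch.zeros(d, r)           # trainable low-rank factor, zero-initialized so training starts from W
scaling = 16 / 32               # lora_alpha / r
W_effective = W + scaling * (B @ A)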
Custom Training Arguments
Customize the training process:
from transformers import Seq2SeqTrainingArguments
custom_training_args = Seq2SeqTrainingArguments(
    output_dir=outdir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    num_train_epochs=3,
    # Add more arguments as needed
)
ft.train(custom_training_args)
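With these settings the effective batch size per optimizer step is per_device_train_batch_size × gradient_accumulation_steps × number of devices, i.e. 8 × 2 = 16 on a single GPU.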
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
License
This project is licensed under the MIT License.
File details
Details for the file wft-0.0.3.tar.gz

File metadata
- Download URL: wft-0.0.3.tar.gz
- Upload date:
- Size: 8.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes
Algorithm | Hash digest
---|---
SHA256 | 53974ff724843526447bfe4f083e1844e1970b9dabfb53f9c4145f982adfda9d
MD5 | 8855d23a7659150cff7ecce518ff7a69
BLAKE2b-256 | 4da1ea289a479947ed23916d2b5754ef292d07ade91091b449e66c93d10a2671
File details
Details for the file wft-0.0.3-py3-none-any.whl

File metadata
- Download URL: wft-0.0.3-py3-none-any.whl
- Upload date:
- Size: 8.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes
Algorithm | Hash digest
---|---
SHA256 | 0c2716a39d5cece93172824c306ac3eb0db599ad754647b8577449b6892d2aa1
MD5 | 3a1e3ef807727f340513a5c8d78e386c
BLAKE2b-256 | 1cc8f3fa8774758e7a5a6510b7aea1a963bc3aee46c74b3a709d8d7edf8b64f9