Run Whisper fine-tuning with ease.

🗣️ Whisper Fine-Tuning (WFT)

WFT is a 🐍 Python library designed to streamline the fine-tuning process of 🤖 OpenAI's Whisper models on custom datasets. It simplifies 📦 dataset preparation, model 🛠️ fine-tuning, and result 💾 saving.

✨ Features

  • 🤗 Hugging Face Integration: Set your organization (or user) name, and everything syncs automatically to the 🤗 Hugging Face Hub.
  • 📄 Easy Dataset Preparation and Preprocessing: Quickly prepare and preprocess datasets for 🛠️ fine-tuning.
  • 🔧 Fine-Tuning Using LoRA (Low-Rank Adaptation): Fine-tune Whisper models with efficient LoRA techniques.
  • ⚙️ Flexible Configuration Options: Customize various aspects of the fine-tuning process.
  • 📊 Evaluation Metrics: Supports Character Error Rate (CER) or Word Error Rate (WER) for model evaluation.
  • 📈 TensorBoard Logging: Track your training progress in real time with TensorBoard.
  • 🤖 Automatic Model Merging and Saving: Automatically merge fine-tuned weights and save the final model.
  • 🔄 Resume Training: Resume training seamlessly from interrupted runs.

🛠️ Installation

Install WFT using 🐍 pip:

pip install wft

🚀 Quick Start

Fine-tune a Whisper model on a custom dataset with just a few steps:

  1. 🧩 Select a Baseline Model: Choose a pre-trained Whisper model.
  2. 🎵 Select a Dataset: Use a dataset that includes 🎧 audio and ✍️ transcription columns.
  3. 🏋️‍♂️ Start Training: Use default training arguments to quickly fine-tune the model.

Here's an example:

from wft import WhisperFineTuner

id = "whisper-large-v3-turbo-zh-TW-test-1"

ft = (
    WhisperFineTuner(id)
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
    .train()  # Use default training arguments
)

That's it! 🎉 You can now fine-tune your Whisper model easily.

To enable 🤗 Hugging Face integration and push your training logs and model to Hugging Face, set the org parameter when initializing WhisperFineTuner:

id = "whisper-large-v3-turbo-zh-TW-test-1"
org = "JacobLinCool"  # Hugging Face organization (or user) to push the model to

ft = (
    WhisperFineTuner(id, org)
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
    .train()  # Use default training arguments
    .merge_and_push()  # Merge the model and push it to Hugging Face
)

This will automatically push your training logs 📜 and the fine-tuned model 🤖 to your Hugging Face account under the specified organization.

📚 Usage Guide

1️⃣ Set Baseline Model and Prepare Dataset

You can use a local dataset or a dataset from 🤗 Hugging Face:

ft = (
    WhisperFineTuner(id)
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
)

To upload the preprocessed dataset to Hugging Face:

ft.push_dataset("username/dataset_name")
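
If push_dataset uploads a standard 🤗 datasets repository (as its name suggests), you can inspect the result with the datasets library before training. A minimal sketch, reusing the placeholder repo name from above:

# Inspect the preprocessed dataset pushed above ("username/dataset_name" is the placeholder from the example).
from datasets import load_dataset

ds = load_dataset("username/dataset_name")
print(ds)  # shows splits, features, and row counts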

You can also pass an already processed dataset: it is loaded directly if it exists, and otherwise prepared from the source dataset given after it:

ft = (
    WhisperFineTuner(id)
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .prepare_dataset(
        "username/preprocessed_dataset",
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
)

2️⃣ Configure Fine-Tuning

Set the evaluation metric and 🔧 LoRA configuration:

ft.set_metric("cer")  # Use CER (Character Error Rate) for evaluation

# Custom LoRA configuration to fine-tune different parts of the model
from peft import LoraConfig

custom_lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)

ft.set_lora_config(custom_lora_config)
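
For reference, this is roughly what such a configuration does when applied with PEFT directly; the sketch below is independent of WFT (which applies the config for you via set_lora_config) and is handy for checking how many parameters a given r and target_modules choice makes trainable:

# Apply the same LoRA config to a Whisper model with plain PEFT
# to see how many parameters become trainable.
from peft import get_peft_model
from transformers import WhisperForConditionalGeneration

base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3-turbo")
peft_model = get_peft_model(base, custom_lora_config)
peft_model.print_trainable_parameters()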

You can also set custom 🛠️ training arguments:

from transformers import Seq2SeqTrainingArguments

custom_training_args = Seq2SeqTrainingArguments(
    output_dir=ft.dir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    num_train_epochs=3,
)

ft.set_training_args(custom_training_args)

3️⃣ Train the Model

To begin 🏋️‍♂️ fine-tuning:

ft.train()
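
To follow the run in TensorBoard, point it at the training output directory. A minimal sketch, assuming the trainer writes its event files under ft.dir (adjust the log directory if your setup differs):

# Launch TensorBoard from Python and point it at the run directory.
# Assumption: event files are written under ft.dir.
from tensorboard import program

tb = program.TensorBoard()
tb.configure(argv=[None, "--logdir", ft.dir])
print(f"TensorBoard is running at {tb.launch()}")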

4️⃣ Save or Push the Fine-Tuned Model

Merge 🔧 LoRA weights with the baseline model and save it:

ft.merge_and_save(f"{ft.dir}/merged_model")

# Or push to Hugging Face
ft.merge_and_push("username/merged_model")
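
Assuming merge_and_save produces a standard transformers checkpoint (which is what merging LoRA weights into the baseline should yield), the merged model can be used like any other Whisper model, for example with the ASR pipeline. In the sketch below, sample.wav is a hypothetical audio file:

# Transcribe with the merged checkpoint via the standard transformers ASR pipeline.
# "sample.wav" is a placeholder for your own audio file.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model=f"{ft.dir}/merged_model",  # or "username/merged_model" if pushed to the Hub
)
print(asr("sample.wav")["text"])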

🔬 Advanced Usage

🔧 Custom LoRA Configuration

Adjust the LoRA configuration to fine-tune different model parts:

custom_lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)

ft.set_lora_config(custom_lora_config)

⚙️ Custom Training Arguments

Specify custom 🛠️ training settings:

from transformers import Seq2SeqTrainingArguments

custom_training_args = Seq2SeqTrainingArguments(
    output_dir=ft.dir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    num_train_epochs=3,
)

ft.set_training_args(custom_training_args)

๐Ÿ” Run Custom Actions After Steps Using .then()

Add actions to be executed after each step:

ft = (
    WhisperFineTuner(id)
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .then(lambda ft: print(f"{ft.baseline_model=}"))
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
    .then(lambda ft: print(f"{ft.dataset=}"))
    .set_metric("cer")
    .then(lambda ft: setattr(ft.training_args, "num_train_epochs", 5))
    .train()
)

🔄 Resume Training From a Checkpoint

If training is interrupted, you can resume:

ft = (
    WhisperFineTuner(id)
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
    .set_metric("cer")
    .train(resume=True)
)

ℹ️ Note: If no checkpoint is found, training starts from scratch instead of failing.

🤝 Contributing

We welcome contributions! 🎉 Feel free to submit a pull request.

📜 License

This project is licensed under the MIT License.

Why are there so many emojis in this README?

Because ChatGPT told me to do so. 🤖📝
