Whisper Fine-Tuning (WFT)

Run Whisper fine-tuning with ease.

WFT is a Python library for fine-tuning OpenAI's Whisper models on custom datasets. It simplifies the process of preparing datasets, fine-tuning models, and saving the results.

Features

  • Easy dataset preparation and preprocessing
  • Fine-tuning Whisper models using LoRA (Low-Rank Adaptation)
  • Support for custom datasets and Hugging Face datasets
  • Flexible configuration options
  • Metric calculation (CER or WER)

Installation

pip install wft

Quick Start

Here's a simple example that fine-tunes a Whisper model on the Common Voice zh-TW dataset:

from wft import WhisperFineTuner

outdir = "output"
ft = (
    WhisperFineTuner(outdir)
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
    .set_metric("cer")
    .set_lora_config()  # Use default LoRA config
    .train()  # Use default training arguments
    .merge_and_save(f"{outdir}/merged_model")
)

Usage

1. Set Baseline Model and Prepare Dataset

You can prepare a dataset from a local source or use a pre-existing Hugging Face dataset:

ft = WhisperFineTuner(outdir).set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
ft.prepare_dataset(
    "mozilla-foundation/common_voice_16_1",
    src_subset="zh-TW",
    src_audio_column="audio",
    src_transcription_column="sentence",
)
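
If your data lives on disk instead, the audio files and transcriptions can first be assembled into a Hugging Face datasets.Dataset object. The following is a minimal sketch using hypothetical file paths and sentences; how such an object is then handed to prepare_dataset depends on the WFT API and is not covered here.

from datasets import Audio, Dataset

# Hypothetical local clips and their transcriptions (illustration only)
local_ds = Dataset.from_dict(
    {
        "audio": ["clips/sample_0001.wav", "clips/sample_0002.wav"],
        "sentence": ["第一句話", "第二句話"],
    }
)
# Decode the paths as audio and resample to 16 kHz, the rate Whisper expects
local_ds = local_ds.cast_column("audio", Audio(sampling_rate=16000))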

2. Configure Fine-tuning

Set the evaluation metric and LoRA configuration:

ft.set_metric("cer")
ft.set_lora_config()  # Use default LoRA config
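
As a point of reference, CER is the character-level edit distance between a prediction and its reference, divided by the reference length. A quick way to see what the metric measures is the evaluate library; this sketch is independent of how WFT computes the metric internally.

import evaluate

cer_metric = evaluate.load("cer")
# One substituted character out of six reference characters ≈ 0.167
score = cer_metric.compute(
    predictions=["今天天氣很好"],
    references=["今天天氣真好"],
)
print(score)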

3. Train the Model

Start the fine-tuning process:

ft.train()  # Use default training arguments

4. Save the Fine-tuned Model

Merge the LoRA weights with the base model and save:

ft.merge_and_save(f"{outdir}/merged_model")

# or push to Hugging Face
ft.merge_and_push("username/merged_model")
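
Because merging folds the LoRA weights back into the base model, the saved directory should be loadable like any other Whisper checkpoint. Below is a minimal inference sketch with transformers, assuming the merged directory at outdir/merged_model also contains the processor files and that example.wav is a local audio file.

from transformers import pipeline

# Load the merged checkpoint for transcription
asr = pipeline(
    "automatic-speech-recognition",
    model=f"{outdir}/merged_model",
)
print(asr("example.wav"))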

Advanced Usage

Custom LoRA Configuration

You can customize the LoRA configuration:

from peft import LoraConfig

custom_lora_config = LoraConfig(
    r=32,  # rank of the low-rank update matrices
    lora_alpha=16,  # scaling factor applied to the LoRA updates
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    lora_dropout=0.05,
    bias="none",
)

ft.set_lora_config(custom_lora_config)

Custom Training Arguments

Customize the training process:

from transformers import Seq2SeqTrainingArguments

custom_training_args = Seq2SeqTrainingArguments(
    output_dir=outdir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    num_train_epochs=3,
    # Add more arguments as needed
)

ft.train(custom_training_args)

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

License

This project is licensed under the MIT License.

