
Run Whisper fine-tuning with ease.

Project description

Whisper Fine-Tuning (WFT)

WFT is a Python library for fine-tuning OpenAI's Whisper models on custom datasets. It streamlines dataset preparation, LoRA-based fine-tuning, and saving or publishing the resulting model.

Features

  • Easy dataset preparation and preprocessing
  • Fine-tuning Whisper models using LoRA (Low-Rank Adaptation)
  • Support for custom datasets and Hugging Face datasets
  • Flexible configuration options
  • Metric calculation (CER or WER)

Installation

pip install wft

Quick Start

Here's a simple example that fine-tunes a Whisper model on a Hugging Face dataset (Common Voice zh-TW):

from wft import WhisperFineTuner

outdir = "output"
ft = (
    WhisperFineTuner(outdir)
    .prepare_dataset(
        "mozilla-foundation/common_voice_16_1",
        src_subset="zh-TW",
        src_audio_column="audio",
        src_transcription_column="sentence",
    )
    .set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
    .set_metric("cer")
    .set_lora_config()  # Use default LoRA config
    .train()  # Use default training arguments
    .merge_and_save(f"{outdir}/merged_model")
)

Usage

1. Set Baseline Model and Prepare Dataset

You can prepare a dataset from a local source or use a pre-existing Hugging Face dataset:

ft = WhisperFineTuner(outdir).set_baseline("openai/whisper-large-v3-turbo", language="zh", task="transcribe")
ft.prepare_dataset(
    "mozilla-foundation/common_voice_16_1",
    src_subset="zh-TW",
    src_audio_column="audio",
    src_transcription_column="sentence",
)
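
If you want to inspect the source columns yourself, the same dataset can be loaded directly with the datasets library. This is a minimal sketch, separate from WFT: Common Voice is gated on the Hugging Face Hub, so you may need to accept its terms and log in with huggingface-cli first.

from datasets import load_dataset

# Stream the train split so nothing is downloaded up front.
cv = load_dataset(
    "mozilla-foundation/common_voice_16_1",
    "zh-TW",
    split="train",
    streaming=True,
)

sample = next(iter(cv))
print(sample["sentence"])                # maps to src_transcription_column
print(sample["audio"]["sampling_rate"])  # maps to src_audio_column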

2. Configure Fine-tuning

Set the evaluation metric and LoRA configuration:

ft.set_metric("cer")
ft.set_lora_config()  # Use default LoRA config
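
CER (character error rate) is usually the right choice for languages that are not whitespace-delimited, such as Chinese, while WER (word error rate) suits languages like English. To illustrate what the metric measures, here is a small sketch using the evaluate and jiwer packages rather than WFT's own implementation:

import evaluate

cer = evaluate.load("cer")
score = cer.compute(
    predictions=["今天天氣很好"],
    references=["今天天氣真好"],
)
print(score)  # 1 substituted character out of 6 -> ~0.167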

3. Train the Model

Start the fine-tuning process:

ft.train()  # Use default training arguments

4. Save the Fine-tuned Model

Merge the LoRA weights with the base model and save:

ft.merge_and_save(f"{outdir}/merged_model")

# or push to Hugging Face
ft.merge_and_push("username/merged_model")
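
Once merged, the saved directory can be loaded like any other Whisper checkpoint. A minimal sketch using the transformers pipeline API (test.wav is a hypothetical local audio file; transcribing from a file path requires ffmpeg to be installed):

from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model=f"{outdir}/merged_model",
)

result = asr("test.wav")  # hypothetical audio file
print(result["text"])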

Advanced Usage

Custom LoRA Configuration

You can customize the LoRA configuration:

from peft import LoraConfig

custom_lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)

ft.set_lora_config(custom_lora_config)
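
Here, r is the rank of the low-rank update matrices (higher means more trainable parameters and capacity), lora_alpha scales the learned update (the effective scaling is lora_alpha / r), target_modules selects which attention projections receive adapters, and lora_dropout applies dropout to the adapter inputs during training.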

Custom Training Arguments

Customize the training process:

from transformers import Seq2SeqTrainingArguments

custom_training_args = Seq2SeqTrainingArguments(
    output_dir=outdir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    num_train_epochs=3,
    # Add more arguments as needed
)

ft.train(custom_training_args)
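
If you override the training arguments and want CER or WER computed during evaluation, make sure predict_with_generate=True is set, since text metrics require generated transcriptions rather than raw logits.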

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

License

This project is licensed under the MIT License.
