# finetune-openai-whisper

A complete, production-ready pipeline for fine-tuning OpenAI's Whisper ASR model on custom datasets using PyTorch Lightning.
## Table of Contents
- Features
- Requirements
- Installation
- Data Preparation
- Quick Start
- Configuration Reference
- Freezing Strategies In Depth
- Weight Untying
- Spectrogram Caching
- Monitoring Training
- Loading a Checkpoint for Inference
- Converting to Official Whisper Format
- Converting to Hugging Face Format
- Troubleshooting
- Acknowledgments
- License
## Features

- All Whisper variants supported — `tiny`, `base`, `small`, `medium`, `large`, `large-v2`, `large-v3`, `turbo`
- Configurable freezing — freeze the full encoder, a specific number of encoder/decoder transformer blocks, or nothing at all
- Weight untying — optionally decouple the decoder's output projection from its token embedding so `lm_head` can adapt while `token_embedding` stays frozen
- Spectrogram caching — optionally cache mel spectrograms to disk so subsequent epochs load instantly
- WER & CER evaluation — word and character error rates computed and logged at every validation step
- TensorBoard integration — all metrics streamed live during training
- Checkpoint management — automatically keeps the top-K best checkpoints ranked by validation WER
- Multi-GPU / DDP ready — includes the required fix for Whisper's sparse `alignment_heads` buffer
- Single config object — every tunable parameter lives in one `Config` dataclass; no scattered hard-coded values
## Requirements
- Python >= 3.8
- PyTorch >= 2.0
- A CUDA-capable GPU is strongly recommended
All Python dependencies are installed automatically (see Installation).
## Installation
Install the latest release from PyPI:
```bash
pip install finetune-openai-whisper
```
Or install directly from source for the latest development version:
```bash
git clone https://github.com/farisalasmary/finetune-openai-whisper
cd finetune-openai-whisper
pip install -e .
```
## Data Preparation
Your training and validation data must be in JSONL format — one JSON object per line, where each object describes a single audio segment:
{"utt": "spk01_utt001", "audio_filepath": "/data/audio/spk01_utt001.wav", "text": "hello world", "duration": 3.2, "offset": 0.0}
{"utt": "spk01_utt002", "audio_filepath": "/data/audio/spk01_utt002.wav", "text": "how are you", "duration": 4.7, "offset": 0.0}
### Field Descriptions

| Field | Type | Description |
|---|---|---|
| `utt` | `str` | Unique utterance ID. Used as the cache filename when spectrogram caching is enabled. |
| `audio_filepath` | `str` | Absolute or relative path to the audio file (WAV, MP3, FLAC, etc.). |
| `text` | `str` | Ground-truth transcription for this segment. |
| `duration` | `float` | Duration of the audio segment in seconds. Used for duration filtering. |
| `offset` | `float` | Start-time offset within the file in seconds. Use `0.0` if the file contains only this segment. |
### Duration Filtering

Segments shorter than `min_duration` or longer than `max_duration` are automatically skipped before training. The total sample count and hours are printed before and after filtering so you can verify your dataset.
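If your corpus starts as a folder of WAV files plus a transcript mapping, the manifest can be generated with a short script. The following is a minimal sketch, not part of the package; it assumes `torchaudio` is installed and that `transcripts` maps utterance IDs to text:

```python
import json
from pathlib import Path

import torchaudio  # assumption: any library that reports audio duration would do


def write_manifest(audio_dir, transcripts, out_path):
    """Write a JSONL manifest from a folder of WAVs and an {utt_id: text} mapping."""
    with open(out_path, "w", encoding="utf-8") as f:
        for wav in sorted(Path(audio_dir).glob("*.wav")):
            info = torchaudio.info(str(wav))
            entry = {
                "utt": wav.stem,
                "audio_filepath": str(wav.resolve()),
                "text": transcripts[wav.stem],
                "duration": round(info.num_frames / info.sample_rate, 2),
                "offset": 0.0,  # each file holds exactly one segment
            }
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
```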
## Quick Start
```python
from finetune_openai_whisper import Config
from finetune_openai_whisper.helpers import prepare_trainer_from_config

cfg = Config(
    model_name="turbo",
    lang="ar",
    train_data="data/train.jsonl",
    val_data="data/val.jsonl",
)

trainer, model, train_dl, val_dl = prepare_trainer_from_config(cfg)
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
```
That's it. The trainer will:

- Load and filter your datasets
- Initialise the Whisper model with the freezing strategy defined in `cfg`
- Apply the DDP sparse tensor fix automatically
- Run training with checkpointing, TensorBoard logging, and LR monitoring
### Full Example Script

Below is a complete `train.py` you can copy, adapt, and run directly:
```python
from finetune_openai_whisper import Config
from finetune_openai_whisper.helpers import prepare_trainer_from_config

cfg = Config(
    # ── Model ─────────────────────────────────────────────────────────────
    model_name="large-v3",
    lang="ar",

    # ── Data ──────────────────────────────────────────────────────────────
    train_data="data/train.jsonl",
    val_data="data/val.jsonl",
    min_duration=1.0,
    max_duration=30.0,
    tmp_folder=None,                 # Set to a path to enable spectrogram caching

    # ── Freezing ──────────────────────────────────────────────────────────
    freeze_encoder=True,             # Freeze the entire encoder (recommended)
    num_frozen_encoder_layers=None,
    freeze_decoder=True,
    num_frozen_decoder_layers=20,

    # ── Weight untying ────────────────────────────────────────────────────
    untie_weights=True,              # Decouple lm_head from token_embedding

    # ── Training ──────────────────────────────────────────────────────────
    num_train_epochs=50,
    train_batch_size=16,
    val_batch_size=8,
    learning_rate=1e-5,
    warmup_steps=500,
    gradient_accumulation_steps=2,   # Effective batch size = 16 × 2 = 32

    # ── Hardware ──────────────────────────────────────────────────────────
    accelerator="auto",
    precision=16,

    # ── Logging & checkpointing ───────────────────────────────────────────
    log_dir="logs/",
    logger_name="arabic_large_v1",   # Change per experiment
    save_top_k=3,
    checkpoint_monitor="val_wer",
)

trainer, model, train_dl, val_dl = prepare_trainer_from_config(cfg)
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
```
## Configuration Reference

All configuration is done through a single `Config` dataclass. Every field has a default value; only override what you need.
### Model

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model_name` | `str` | `"turbo"` | Whisper model variant: `tiny`, `base`, `small`, `medium`, `large`, `large-v2`, `large-v3`, `turbo`. |
| `lang` | `str` | `"ar"` | Target language code (e.g. `"en"`, `"ar"`, `"fr"`, `"zh"`). |
### Freezing Strategy

| Parameter | Type | Default | Description |
|---|---|---|---|
| `freeze_encoder` | `bool` | `True` | Freeze the encoder during training. |
| `num_frozen_encoder_layers` | `Optional[int]` | `None` | Encoder transformer blocks to freeze. See Freezing Strategies In Depth. |
| `freeze_decoder` | `bool` | `False` | Freeze the decoder during training. |
| `num_frozen_decoder_layers` | `Optional[int]` | `None` | Decoder transformer blocks to freeze. See Freezing Strategies In Depth. |
### Weight Untying

| Parameter | Type | Default | Description |
|---|---|---|---|
| `untie_weights` | `bool` | `False` | Decouple the decoder's output projection (`lm_head`) from `token_embedding`. See Weight Untying. |
### Data

| Parameter | Type | Default | Description |
|---|---|---|---|
| `train_data` | `str` | `"YOUR_TRAIN_DATA.jsonl"` | Path to the training JSONL file. |
| `val_data` | `str` | `"YOUR_VAL_DATA.jsonl"` | Path to the validation JSONL file. |
| `min_duration` | `float` | `5.0` | Segments shorter than this (seconds) are skipped. |
| `max_duration` | `float` | `30.0` | Segments longer than this (seconds) are skipped. |
| `sample_rate` | `int` | `16000` | Audio sample rate. Whisper always expects 16 kHz — do not change this. |
| `tmp_folder` | `Optional[str]` | `None` | Directory for caching mel spectrograms. `None` disables caching. |
| `storage_threshold_gb` | `float` | `100.0` | Minimum free disk space (GB) required before a spectrogram is cached. |
### Optimizer

| Parameter | Type | Default | Description |
|---|---|---|---|
| `learning_rate` | `float` | `1e-5` | Peak learning rate for AdamW. |
| `weight_decay` | `float` | `0.01` | L2 regularisation applied to all parameters except biases and LayerNorm weights (see the sketch below). |
| `adam_epsilon` | `float` | `1e-8` | Epsilon for numerical stability in AdamW. |
| `warmup_steps` | `int` | `2000` | Linear warmup steps before the LR reaches its peak. |
| `gradient_accumulation_steps` | `int` | `1` | Effective batch size = `train_batch_size` × `gradient_accumulation_steps`. |
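The bias/LayerNorm exemption for `weight_decay` is the standard two-parameter-group pattern. As a rough sketch of what such a grouping looks like (the package's exact implementation may differ):

```python
import torch


def build_optimizer(model, lr=1e-5, weight_decay=0.01, eps=1e-8):
    """Sketch: AdamW with biases and normalisation weights exempt from decay."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters need no optimizer state
        if param.ndim < 2:
            no_decay.append(param)  # biases and LayerNorm weights are 1-D
        else:
            decay.append(param)     # linear, conv, and embedding matrices
    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=lr,
        eps=eps,
    )
```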
### Training Loop

| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_train_epochs` | `int` | `200` | Maximum number of training epochs. |
| `seed` | `int` | `1415` | Global random seed for reproducibility. |
| `train_batch_size` | `int` | `32` | Samples per training batch. |
| `val_batch_size` | `int` | `16` | Samples per validation batch. |
| `train_num_workers` | `int` | `32` | DataLoader worker processes for training. |
| `val_num_workers` | `int` | `16` | DataLoader worker processes for validation. |
### Hardware

| Parameter | Type | Default | Description |
|---|---|---|---|
| `accelerator` | `str` | `"auto"` | PyTorch Lightning accelerator. `"auto"` picks a GPU if available, else CPU. |
| `precision` | `int` | `16` | `16` for mixed precision (recommended), `32` for full precision, `"bf16-mixed"` on Ampere+ GPUs. |
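If you are unsure which precision to request, you can choose at runtime based on hardware support. A small sketch:

```python
import torch

from finetune_openai_whisper import Config

# Prefer bf16 on GPUs that support it (Ampere or newer), else fall back to fp16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
cfg = Config(accelerator="auto", precision="bf16-mixed" if use_bf16 else 16)
```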
### Logging

| Parameter | Type | Default | Description |
|---|---|---|---|
| `log_dir` | `str` | `"logs/"` | Root directory for TensorBoard logs. |
| `logger_name` | `str` | `"whisper_turbo_v1"` | Experiment sub-directory inside `log_dir`. Change per run. |
| `log_every_n_steps` | `int` | `1` | How often (in optimiser steps) to write metrics to TensorBoard. |
| `lr_monitor_logging_interval` | `str` | `"epoch"` | LR logging frequency: `"step"` or `"epoch"`. |
### Checkpointing

| Parameter | Type | Default | Description |
|---|---|---|---|
| `checkpoint_dirpath` | `str` | `"logs/checkpoint"` | Directory for `.ckpt` files. |
| `checkpoint_filename` | `str` | `"whisper-finetuned-{epoch:04d}-{val_loss:.5f}-{val_wer:.5f}-{val_cer:.5f}"` | Filename template with metric placeholders. |
| `checkpoint_monitor` | `str` | `"val_wer"` | Metric used to rank and keep the best checkpoints. |
| `checkpoint_monitor_mode` | `str` | `"min"` | `"min"` for WER/CER/loss, `"max"` for accuracy. |
| `save_top_k` | `int` | `5` | Number of best checkpoints to retain on disk. |
## Freezing Strategies In Depth

Whisper is an encoder-decoder model. The encoder converts audio into dense representations; the decoder generates text tokens from those representations.

- Encoder layout: `conv1 → conv2 → blocks[0 … N-1] → ln_post`
- Decoder layout: `token_embedding → blocks[0 … N-1] → ln`

The four freezing parameters interact as follows:
Encoder
# Freeze the entire encoder (recommended starting point)
cfg = Config(freeze_encoder=True, num_frozen_encoder_layers=None)
# Freeze only the convolutional front-end (conv1, conv2) and ln_post;
# all transformer blocks remain trainable
cfg = Config(freeze_encoder=True, num_frozen_encoder_layers=0)
# Freeze the front-end + the first 4 transformer blocks
cfg = Config(freeze_encoder=True, num_frozen_encoder_layers=4)
# Full encoder fine-tuning — nothing is frozen
cfg = Config(freeze_encoder=False)
### Decoder

```python
# Decoder fully trainable (default)
cfg = Config(freeze_decoder=False)

# Freeze the entire decoder
cfg = Config(freeze_decoder=True, num_frozen_decoder_layers=None)

# Freeze only token_embedding and ln; all transformer blocks trainable
cfg = Config(freeze_decoder=True, num_frozen_decoder_layers=0)

# Freeze token_embedding, ln, and the first 2 decoder blocks
cfg = Config(freeze_decoder=True, num_frozen_decoder_layers=2)
```
### Common Recipes

| Goal | Settings |
|---|---|
| Fast fine-tuning with minimal memory (default) | `freeze_encoder=True, num_frozen_encoder_layers=None, freeze_decoder=False` |
| Fine-tune the top encoder layers only | `freeze_encoder=True, num_frozen_encoder_layers=N` (bottom N frozen, rest trainable) |
| Full model fine-tuning (most data required) | `freeze_encoder=False, freeze_decoder=False` |
| Frozen encoder + frozen lower decoder layers | `freeze_encoder=True, freeze_decoder=True, num_frozen_decoder_layers=N` |
Note: The number of transformer blocks varies by model. For example, `turbo` has 32 encoder blocks and 4 decoder blocks; `base` has 6 and 6. Passing a value larger than the actual block count freezes all blocks without error.
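Mechanically, freezing just disables gradients for the affected modules. An illustrative sketch of the encoder case (not the package's exact code) to make the semantics above concrete:

```python
def freeze_encoder_parts(model, num_frozen_layers=None):
    """Illustrative: freeze Whisper encoder components via requires_grad."""
    encoder = model.encoder
    if num_frozen_layers is None:
        encoder.requires_grad_(False)  # freeze the entire encoder
        return
    # Freeze the convolutional front-end and the final LayerNorm...
    for module in (encoder.conv1, encoder.conv2, encoder.ln_post):
        module.requires_grad_(False)
    # ...plus the first num_frozen_layers transformer blocks.
    for block in encoder.blocks[:num_frozen_layers]:
        block.requires_grad_(False)
```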
## Weight Untying

By default, Whisper ties its decoder output projection to `token_embedding` — the same weight matrix is used for both input embedding lookup and output logit computation.

Setting `untie_weights=True` creates an independent `lm_head` linear layer initialised from a copy of those weights, then patches the decoder's forward method to use `lm_head` for logit computation. This is most useful in combination with decoder freezing:
```python
cfg = Config(
    freeze_decoder=True,
    num_frozen_decoder_layers=0,  # freeze token_embedding + ln; blocks trainable
    untie_weights=True,           # lm_head is now an independent trainable projection
)
```
With this setup, `token_embedding` stays frozen (preserving the model's vocabulary representations from pre-training) while `lm_head` can still adapt to the fine-tuning domain.
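Conceptually, untying amounts to copying the shared matrix into a separate projection. A sketch (illustrative only; the package patches the decoder forward itself):

```python
import torch.nn as nn
import whisper

model = whisper.load_model("turbo")

# In tied mode, logits are computed as hidden_states @ token_embedding.weight.T.
# Untying copies those weights into an independent linear projection:
lm_head = nn.Linear(model.dims.n_text_state, model.dims.n_vocab, bias=False)
lm_head.weight.data.copy_(model.decoder.token_embedding.weight.data)
# Training can now update lm_head.weight without touching token_embedding.
```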
## Spectrogram Caching
Computing mel spectrograms on the fly is CPU-intensive. Enabling caching saves significant time after the first epoch.
```python
cfg = Config(
    tmp_folder="tmp/spectrograms",  # cache directory
    storage_threshold_gb=50.0,      # only cache if > 50 GB free
)
```
- On the first epoch, spectrograms are computed from audio and saved to `tmp_folder` as `.pt` files (one per utterance, keyed by `utt` ID).
- On subsequent epochs, the cached `.pt` files are loaded directly, bypassing audio decoding and FFT computation entirely.
- If free disk space falls below `storage_threshold_gb`, the spectrogram is computed on the fly and not cached, so training continues safely even if the disk fills up.
- Set `tmp_folder=None` (default) to always compute spectrograms on the fly.
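The caching behaviour is roughly equivalent to the following sketch (hypothetical helper names; the package's internals may differ):

```python
import os
import shutil

import torch


def get_mel(utt_id, audio_path, tmp_folder, storage_threshold_gb, compute_mel):
    """Load a cached mel spectrogram if present; otherwise compute (and maybe cache) it."""
    if tmp_folder is None:
        return compute_mel(audio_path)  # caching disabled
    cache_path = os.path.join(tmp_folder, f"{utt_id}.pt")
    if os.path.exists(cache_path):
        return torch.load(cache_path)   # cache hit: skips decoding and FFT
    mel = compute_mel(audio_path)
    free_gb = shutil.disk_usage(tmp_folder).free / 1e9
    if free_gb > storage_threshold_gb:  # cache only while disk space is plentiful
        torch.save(mel, cache_path)
    return mel
```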
## Monitoring Training
Start TensorBoard in a separate terminal to watch metrics live:
```bash
tensorboard --logdir=logs/
```
The following metrics are logged:
| Metric | When | Description |
|---|---|---|
| `train_loss` | Every step | Cross-entropy loss on the training batch |
| `val_loss` | Every validation step | Cross-entropy loss on the validation batch |
| `val_wer` | Every validation step | Word error rate across the validation batch |
| `val_cer` | Every validation step | Character error rate across the validation batch |
| `lr` | Every epoch (default) | Current learning rate from the scheduler |
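To sanity-check the logged numbers offline, WER and CER can be recomputed with, for example, the `jiwer` library (an illustration, not necessarily what this package uses internally):

```python
import jiwer

references = ["hello world", "how are you"]
hypotheses = ["hello word", "how are you"]

print(jiwer.wer(references, hypotheses))  # word error rate: 1 substitution / 5 words = 0.2
print(jiwer.cer(references, hypotheses))  # character error rate
```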
## Loading a Checkpoint for Inference

After training, load any saved `.ckpt` file for transcription:
```python
from finetune_openai_whisper import Config, WhisperModelModule

cfg = Config(model_name="turbo", lang="ar")

model = WhisperModelModule.load_from_checkpoint(
    "logs/checkpoint/whisper-finetuned-epoch=0010-....ckpt",
    cfg=cfg,
    model_name=cfg.model_name,
    lang=cfg.lang,
)
model.eval()

result = model.model.transcribe("path/to/audio.wav")
print(result["text"])
```
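To spot-check a few validation utterances, you can iterate over the manifest. A sketch continuing from the snippet above (`model` and `cfg` as defined there):

```python
import json

with open("data/val.jsonl", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i == 3:
            break  # only look at the first three utterances
        entry = json.loads(line)
        result = model.model.transcribe(entry["audio_filepath"], language=cfg.lang)
        print(f'{entry["text"]!r} -> {result["text"]!r}')
```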
## Converting to Official Whisper Format

The PyTorch Lightning `.ckpt` format wraps model weights with training metadata. To use your fine-tuned model with the standard `whisper.load_model()` API, convert it first:
```bash
python -m finetune_openai_whisper.convert_ckpt_to_official_whisper_format \
    turbo \
    logs/checkpoint/your_checkpoint.ckpt \
    whisper_turbo_finetuned.pt
```
After conversion, load as a standard Whisper model:
```python
import whisper

model = whisper.load_model("whisper_turbo_finetuned.pt")
result = model.transcribe("audio.wav")
print(result["text"])
```
If you fine-tuned the model with untied weights, load the checkpoint as follows:
```python
import torch
import whisper

from finetune_openai_whisper.helpers import untie_embed_n_output_weights

model = whisper.load_model("turbo")
untie_embed_n_output_weights(model)  # recreate the independent lm_head before loading

finetuned_model_path = "whisper_turbo_finetuned.pt"
state_dict = torch.load(finetuned_model_path)["model_state_dict"]
model.load_state_dict(state_dict)

result = model.transcribe("audio.wav")
print(result["text"])
```
## Converting to Hugging Face Format

To use your fine-tuned model with the 🤗 Transformers library:

- First convert to the official Whisper format as described above.
- Then run the Whisper checkpoint converter provided by Hugging Face Transformers.

Note: the converter script works only on fine-tuned models with tied weights.
```bash
python convert_openai_to_hf.py \
    --checkpoint_path whisper_finetuned.pt \
    --pytorch_dump_folder_path ./whisper-hf \
    --convert_preprocessor True
```
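Once converted, the folder should load like any local 🤗 Whisper checkpoint. A quick usage sketch:

```python
from transformers import pipeline

# "./whisper-hf" is the dump folder produced by the converter above.
asr = pipeline("automatic-speech-recognition", model="./whisper-hf")
print(asr("audio.wav")["text"])
```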
## Troubleshooting

### CUDA Out of Memory

Reduce the per-GPU batch size and compensate with gradient accumulation; mixed precision and a frozen encoder also cut memory use substantially:
```python
cfg = Config(
    train_batch_size=8,
    gradient_accumulation_steps=4,  # maintains effective batch size of 32
    precision=16,
    freeze_encoder=True,
)
```
### Slow Data Loading

Increase the number of DataLoader workers and enable spectrogram caching:
```python
cfg = Config(
    train_num_workers=8,            # tune to your CPU core count
    tmp_folder="tmp/spectrograms",  # enable caching to skip repeated FFT work
)
```
### Poor Convergence

Try a lower learning rate with a shorter warmup, and keep the encoder frozen:
```python
cfg = Config(
    learning_rate=5e-6,
    warmup_steps=500,  # shorter warmup for smaller datasets
    freeze_encoder=True,
)
```
### Sparse Tensor Error with DDP

This is handled automatically by `prepare_trainer_from_config()`. The `alignment_heads` buffer in Whisper is sparse and incompatible with PyTorch DDP; it is converted to a dense tensor before training begins.
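If you assemble your own Lightning `Trainer` instead, the same fix boils down to densifying that buffer before DDP wraps the model. A sketch, assuming `model` is the raw Whisper module:

```python
# DDP cannot broadcast sparse buffers, so replace alignment_heads with a dense copy.
dense_heads = model.get_buffer("alignment_heads").to_dense()
model.register_buffer("alignment_heads", dense_heads, persistent=False)
```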
### Checkpoint Not Found for Inference

Checkpoint filenames include metric values rendered at save time. Use a glob pattern to find the right file:
```bash
ls logs/checkpoint/whisper-finetuned-*.ckpt
```
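Or select the lowest-WER checkpoint programmatically. A sketch that assumes the default filename template shown in the Checkpointing table:

```python
import glob
import re


def best_checkpoint(pattern="logs/checkpoint/whisper-finetuned-*.ckpt"):
    """Return the checkpoint path with the smallest val_wer in its filename."""
    def wer_of(path):
        match = re.search(r"val_wer=([0-9]+\.[0-9]+)", path)
        return float(match.group(1)) if match else float("inf")
    return min(glob.glob(pattern), key=wer_of)


print(best_checkpoint())
```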
## Acknowledgments
- Training pipeline inspired by this Colab notebook
- WER/CER evaluation adapted from abjadai/catt
## License
This project is licensed under the MIT License. See LICENSE for details.