
Improve OpenAI's Whisper with RoPE and the latest OpenAI tokenizers


NeoWhisper

Improve OpenAI's Whisper by integrating Rotary Positional Embeddings (RoPE) and by adding more tokenizer options from the tiktoken PyPI package.

Support My Work

While this work comes from the heart, each project represents a significant investment of time, from in-depth research and code preparation to the final writing and editing. I am passionate about sharing this knowledge, but maintaining this level of quality is a major undertaking. If you find my work helpful and are in a position to do so, please consider supporting it with a donation. You can click here to donate or scan the QR code below. Your generosity is a huge encouragement and helps ensure that I can keep creating in-depth, valuable content for you.

If you have a Cambodian bank account, you can donate by scanning my ABA QR code here (or click here). Make sure that the receiver's name is 'Khun Kim Ang'.

Installation

pip install neo-whisper

Requirement

pip install git+https://github.com/openai/whisper.git

Usage

Loading tokenizer

from neo_whisper import get_tokenizer
tokenizer_name = 'cl100k_base'
tokenizer = get_tokenizer(multilingual=True, language='km', task='transcribe', encoder_name=tokenizer_name)
print(tokenizer.eot)
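With multilingual=True, the language and task arguments select the special tokens that open every decoder sequence. The sketch below shows that prefix as token strings, assuming NeoWhisper keeps openai-whisper's special-token names and its defaults (English and "transcribe" when nothing is given). `sot_prefix` is a hypothetical helper for illustration, not part of neo_whisper.

```python
from typing import List, Optional


def sot_prefix(multilingual: bool, language: Optional[str] = None,
               task: Optional[str] = None) -> List[str]:
    """Hypothetical helper: special tokens a Whisper-style decoder starts from."""
    if not multilingual:
        # English-only setups in openai-whisper use no language/task tokens
        return ["<|startoftranscript|>"]
    language = language or "en"    # openai-whisper's default language
    task = task or "transcribe"    # openai-whisper's default task
    if task not in ("transcribe", "translate"):
        raise ValueError(f"unknown task: {task!r}")
    return ["<|startoftranscript|>", f"<|{language}|>", f"<|{task}|>"]


print(sot_prefix(multilingual=True, language="km", task="transcribe"))
# ['<|startoftranscript|>', '<|km|>', '<|transcribe|>']
```

The actual integer ids behind these strings depend on the chosen tiktoken encoding, which is one reason a model only works with the tokenizer it was trained with.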

Loading NeoWhisper model

from neo_whisper import NeoWhisper, NeoModelDimensions
dims = NeoModelDimensions(
    n_vocab=tokenizer.encoding.n_vocab, # use the tokenizer's vocab size
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=384,
    n_audio_head=6,
    n_audio_layer=4,
    n_text_ctx=448,
    n_text_state=384,
    n_text_head=6,
    n_text_kv_head=6,
    n_text_layer=4
)
model = NeoWhisper(dims)

This model works like OpenAI's original Whisper model; in fact, NeoWhisper inherits from Whisper in openai-whisper. The difference lies in the TextDecoder: NeoWhisper's TextDecoder integrates RoPE, while Whisper's does not.
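RoPE encodes position by rotating each pair of query/key features through an angle proportional to the token's position, so the attention score between two tokens depends only on how far apart they are. The pure-Python sketch below demonstrates that property using the interleaved-pair variant with the commonly used base of 10000. It is an illustration of the technique, not NeoWhisper's implementation, which may pair dimensions differently or use another base.

```python
import math
from typing import List


def rope(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Rotate each feature pair (x[i], x[i+1]) by the angle pos * base**(-i/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even head dimension")
    out: List[float] = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)  # later pairs rotate more slowly
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5]  # toy query vector (head_dim = 4)
k = [1.0, 0.4, -0.7, 0.2]  # toy key vector

# Same offset (3 positions apart) at different absolute positions gives the same score.
s1 = dot(rope(q, 5), rope(k, 2))
s2 = dot(rope(q, 13), rope(k, 10))
print(math.isclose(s1, s2, abs_tol=1e-9))  # True
```

Because rotations preserve vector length, RoPE changes only the angle between queries and keys, never their magnitudes; there is no learned positional-embedding table to size against n_text_ctx.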

Loading Original Whisper model

You can also load the original model architecture implemented in openai-whisper, but with a new tokenizer (such as cl100k_base).

from neo_whisper import Whisper, ModelDimensions
dims = ModelDimensions(
    n_vocab=tokenizer.encoding.n_vocab, # use the tokenizer's vocab size
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=384,
    n_audio_head=6,
    n_audio_layer=4,
    n_text_ctx=448,
    n_text_state=384,
    n_text_head=6,
    n_text_layer=4
)
model = Whisper(dims)
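Whichever class you use, the attention sizes have to fit together. The sketch below checks three constraints that are our assumptions rather than documented NeoWhisper rules: n_text_state must split evenly across n_text_head (this one applies to both classes), the per-head size must be even so RoPE can rotate feature pairs, and, reading n_text_kv_head as the number of key/value heads in grouped-query attention, n_text_head must be a multiple of it. `TextDims` and `check_text_dims` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class TextDims:
    """Hypothetical stand-in for the text-side fields of NeoModelDimensions."""
    n_text_state: int
    n_text_head: int
    n_text_kv_head: int


def check_text_dims(d: TextDims) -> int:
    """Return the per-head size after checking the (assumed) constraints."""
    if d.n_text_state % d.n_text_head:
        raise ValueError("n_text_state must be divisible by n_text_head")
    head_dim = d.n_text_state // d.n_text_head
    if head_dim % 2:
        raise ValueError("RoPE rotates feature pairs, so the head size must be even")
    if d.n_text_head % d.n_text_kv_head:
        raise ValueError("n_text_head must be a multiple of n_text_kv_head")
    return head_dim


print(check_text_dims(TextDims(n_text_state=384, n_text_head=6, n_text_kv_head=6)))  # 64
# Under the GQA reading, n_text_kv_head=2 would let each key/value head serve 3 query heads.
print(check_text_dims(TextDims(n_text_state=384, n_text_head=6, n_text_kv_head=2)))  # 64
```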

NOTE: When using a new tokenizer, you need to train the TextDecoder of your model.
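The reason is that the tokenizer fixes n_vocab, and n_vocab sets the shape of the decoder's token-embedding matrix, which openai-whisper's TextDecoder also reuses (transposed) as its output projection. With a new tokenizer there are no pretrained weights for that matrix, and even an equal vocabulary size would not help, since row i would stand for a different token. The sketch below makes the shape argument concrete using two real parameter names and shapes from OpenAI's multilingual "tiny" checkpoint; the new vocabulary size and `split_by_shape` are hypothetical.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

# Shapes from OpenAI's multilingual "tiny" checkpoint (n_vocab=51865, n_mels=80, state 384).
pretrained: Dict[str, Shape] = {
    "encoder.conv1.weight": (384, 80, 3),
    "decoder.token_embedding.weight": (51865, 384),
}

NEW_VOCAB = 100_300  # hypothetical value of tokenizer.encoding.n_vocab
new_model: Dict[str, Shape] = {
    "encoder.conv1.weight": (384, 80, 3),
    "decoder.token_embedding.weight": (NEW_VOCAB, 384),
}


def split_by_shape(old: Dict[str, Shape], new: Dict[str, Shape]) -> Tuple[List[str], List[str]]:
    """Split parameter names into (shape matches pretrained, must be trained anew)."""
    reusable = sorted(k for k, s in new.items() if old.get(k) == s)
    retrain = sorted(k for k, s in new.items() if old.get(k) != s)
    return reusable, retrain


reusable, retrain = split_by_shape(pretrained, new_model)
print(reusable)  # ['encoder.conv1.weight']
print(retrain)   # ['decoder.token_embedding.weight']
```

With real models, the same dictionaries can be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`.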

Train TextDecoder

You can check out the notebook below to train your own NeoWhisper. Note that you can train NeoWhisper with any tokenizer available in the tiktoken PyPI package, and I recommend doing so for the Khmer language.
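Training a text decoder is usually done with teacher forcing: the decoder reads the prefix tokens plus the transcript and learns to predict the same sequence shifted one position to the left, ending with the end-of-transcript token (tokenizer.eot). The sketch below builds one such input/target pair. It is a generic recipe and not necessarily what the notebook does; the token ids are made up, and masking the prefix targets with -100 (the default ignore_index of torch.nn.CrossEntropyLoss) is one common choice among several.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def teacher_forcing_pair(prefix: Sequence[int], text_ids: Sequence[int], eot: int,
                         mask_prefix: bool = True) -> Tuple[List[int], List[int]]:
    """Build (decoder input, target) for next-token prediction."""
    seq = list(prefix) + list(text_ids) + [eot]
    inp, tgt = seq[:-1], seq[1:]  # target at step t is the token at step t + 1
    if mask_prefix:
        # targets that are themselves prefix tokens (language, task, ...) are ignored
        n = max(len(prefix) - 1, 0)
        tgt = [IGNORE_INDEX] * n + tgt[n:]
    return inp, tgt


# Made-up ids standing in for <|startoftranscript|>, <|km|>, <|transcribe|>, text, and eot.
inp, tgt = teacher_forcing_pair(prefix=[900, 901, 902], text_ids=[17, 42, 8], eot=999)
print(inp)  # [900, 901, 902, 17, 42, 8]
print(tgt)  # [-100, -100, 17, 42, 8, 999]
```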

Open in Colab

I also have a video about training the TextDecoder of NeoWhisper:

Watch the video

Remark

When the AudioEncoder config is the same as that of an original Whisper audio encoder trained by OpenAI, you can load OpenAI's pre-trained encoder weights and train only the text decoder. To build NeoWhisper with OpenAI Whisper's AudioEncoder, pass neo_encoder=False when initializing NeoWhisper (by default, neo_encoder=True).

from neo_whisper import NeoWhisper, NeoModelDimensions
import whisper

dims = NeoModelDimensions(
    n_vocab=tokenizer.encoding.n_vocab, # use the tokenizer's vocab size
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=384,
    n_audio_head=6,
    n_audio_layer=4,
    n_text_ctx=448,
    n_text_state=384,
    n_text_head=6,
    n_text_kv_head=6,
    n_text_layer=4
)
model = NeoWhisper(dims, neo_encoder=False)
# load pre-trained audio-encoder weights (the audio dims above match OpenAI's "tiny" model)
model.encoder.load_state_dict(whisper.load_model("tiny").encoder.state_dict())
# freeze the pre-trained weights so that only the text decoder is trained
for p in model.encoder.parameters():
    p.requires_grad = False
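The encoder weights only load if every audio field matches the chosen checkpoint. The sketch below compares a dims dictionary against the audio-encoder dimensions of a few OpenAI checkpoints, copied by hand from openai-whisper's model configurations (worth double-checking against the version you install), to find the name to pass to whisper.load_model. `OPENAI_AUDIO_DIMS` and `matching_checkpoints` are hypothetical names.

```python
from typing import Dict, List

# Audio-encoder dimensions of some OpenAI Whisper checkpoints; the ".en"
# variants share the encoder dimensions of their multilingual counterparts.
OPENAI_AUDIO_DIMS: Dict[str, Dict[str, int]] = {
    "tiny":   dict(n_mels=80, n_audio_ctx=1500, n_audio_state=384,  n_audio_head=6,  n_audio_layer=4),
    "base":   dict(n_mels=80, n_audio_ctx=1500, n_audio_state=512,  n_audio_head=8,  n_audio_layer=6),
    "small":  dict(n_mels=80, n_audio_ctx=1500, n_audio_state=768,  n_audio_head=12, n_audio_layer=12),
    "medium": dict(n_mels=80, n_audio_ctx=1500, n_audio_state=1024, n_audio_head=16, n_audio_layer=24),
}


def matching_checkpoints(dims: Dict[str, int]) -> List[str]:
    """Names of checkpoints whose audio encoder has exactly these dimensions."""
    return [name for name, ref in OPENAI_AUDIO_DIMS.items()
            if all(dims.get(key) == value for key, value in ref.items())]


# Audio fields from the example above; text-side fields are simply ignored.
my_dims = dict(n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6,
               n_audio_layer=4, n_text_kv_head=6)
print(matching_checkpoints(my_dims))  # ['tiny']
```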

Transcription

You can use a trained model for transcription in the same way as with the openai-whisper package. The only difference is that you must specify tokenizer_name correctly: the tokenizer used for transcription must be the tokenizer used to train the model, so tokenizer_name must be passed to transcribe.

import torch
import whisper  # openai-whisper, used here only to load audio
from neo_whisper import (
    get_tokenizer,
    NeoWhisper,
    NeoModelDimensions,
    transcribe
)
tokenizer_name = 'cl100k_base'  # must match the tokenizer used for training
tokenizer = get_tokenizer(multilingual=True, task='transcribe', encoder_name=tokenizer_name)
dims = NeoModelDimensions(
    n_vocab=tokenizer.encoding.n_vocab, # use the tokenizer's vocab size
    n_mels=80,       # number of mel-frequency bins; must match training
    n_audio_ctx=1500,
    n_audio_state=384,
    n_audio_head=6,
    n_audio_layer=4,
    n_text_ctx=448,
    n_text_state=384,
    n_text_head=6,
    n_text_kv_head=6,
    n_text_layer=4
)
model = NeoWhisper(dims, neo_encoder=False) # use neo_encoder=True if the model was trained with NeoWhisper's audio encoder
best_model_params_path = "path/to/your/weights.pt"
model.load_state_dict(torch.load(best_model_params_path, map_location="cpu"))

# whisper.load_audio decodes the file with ffmpeg into a 16 kHz mono float32 array
audio_array = whisper.load_audio("path/to/audio.wav")
result = transcribe(model, audio_array, verbose=True, tokenizer_name=tokenizer_name)
print(result['text'])
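A mismatched tokenizer_name can surface as a shape error when loading the weights or, worse, as wrong text. One option is to save the tokenizer name next to the weights and check it before transcribing. The helpers below are a hypothetical, standard-library-only sketch; openai-whisper's own checkpoints follow a similar idea by storing the model dims alongside the state dict.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Optional


def save_run_config(path: str, tokenizer_name: str, dims: dict) -> None:
    """Hypothetical helper: write the tokenizer name and model dims as JSON."""
    Path(path).write_text(json.dumps({"tokenizer_name": tokenizer_name, "dims": dims}, indent=2))


def load_run_config(path: str, expected_tokenizer: Optional[str] = None) -> dict:
    """Hypothetical helper: read the JSON back and fail loudly on a tokenizer mismatch."""
    cfg = json.loads(Path(path).read_text())
    if expected_tokenizer is not None and cfg["tokenizer_name"] != expected_tokenizer:
        raise ValueError(
            f"weights were trained with {cfg['tokenizer_name']!r}, not {expected_tokenizer!r}"
        )
    return cfg


with tempfile.TemporaryDirectory() as tmp:
    cfg_path = os.path.join(tmp, "weights.json")
    save_run_config(cfg_path, "cl100k_base", {"n_text_state": 384, "n_text_layer": 4})
    cfg = load_run_config(cfg_path, expected_tokenizer="cl100k_base")
    print(cfg["tokenizer_name"])  # cl100k_base
```

At transcription time you would then pass tokenizer_name=cfg["tokenizer_name"] to transcribe. If NeoModelDimensions is a dataclass like openai-whisper's ModelDimensions (an assumption), dataclasses.asdict(dims) gives the dictionary to store.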

TODO:

  • implement decoding function for NeoWhisper and Whisper
  • implement transcription for NeoWhisper and Whisper
  • notebook colab for training NeoWhisper
  • benchmarking


File details

Details for the file neo_whisper-0.1.3.tar.gz.

File metadata

  • Download URL: neo_whisper-0.1.3.tar.gz
  • Size: 28.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.10

File hashes

Hashes for neo_whisper-0.1.3.tar.gz
Algorithm Hash digest
SHA256 d1f0d4be2359cc161f31d05e976fa3f2d4cc793894b748ddb63d5868789d6ac4
MD5 f2f3ab2a7c220c36525338270aa2cab0
BLAKE2b-256 12e9e3dd1c6adfa00aa8e7cb7bcd7dd9998b7bc6129992070f060a32feadbb50

