MLX-HuBERT
A pure MLX implementation of HuBERT (Hidden Unit BERT) for Apple Silicon, providing efficient speech representation learning and automatic speech recognition.
Features
- 🚀 Optimized for Apple Silicon - Leverages MLX framework for efficient computation on M1/M2/M3 chips
- 🎯 Compatible with HuggingFace - Load pretrained HuBERT models from HuggingFace Hub
- 🔧 Easy to use - Simple API similar to Transformers
- 📊 Efficient - Faster inference compared to CPU-based implementations
- 🎤 Speech Recognition - Built-in CTC decoding for automatic speech recognition
Installation
pip install mlx-hubert
Or install from source:
git clone https://github.com/mzbac/mlx-hubert.git
cd mlx-hubert
pip install -e .
Quick Start
import mlx.core as mx
from mlx_hubert import load_model, HubertProcessor
from datasets import load_dataset
# Load processor and model
processor = HubertProcessor(sampling_rate=16000)
model, config = load_model("mzbac/hubert-large-ls960-ft")
# Load audio dataset
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
# Process audio
inputs = processor(ds[0]["audio"]["array"])
input_values = inputs["input_values"]
# Generate transcription
logits = model(input_values).logits
predicted_ids = mx.argmax(logits, axis=-1)
transcription = processor.decode(predicted_ids[0])
print(transcription)
# Output: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
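The argmax-and-decode step above is greedy CTC decoding: repeated frame predictions are collapsed and blank tokens are dropped. A minimal standalone sketch of that logic (assuming the blank token has id 0, as is typical for CTC vocabularies):
def greedy_ctc_decode(frame_ids, blank_id=0):
    # Collapse consecutive duplicates, then drop blank tokens.
    tokens = []
    previous = None
    for token_id in frame_ids:
        if token_id != previous and token_id != blank_id:
            tokens.append(token_id)
        previous = token_id
    return tokens
print(greedy_ctc_decode([0, 5, 5, 0, 6, 6, 6, 0]))  # [5, 6]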
Model Architecture
MLX-HuBERT implements the full HuBERT architecture (see the shape-level sketch after this list):
- Feature Encoder: Convolutional layers that process raw audio waveforms
- Feature Projection: Projects CNN features to transformer dimension
- Transformer Encoder: Self-attention layers for learning representations
- CTC Head: Linear layer for character/token prediction (ASR models)
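As a rough illustration of how these stages compose (a sketch, not part of the package's examples; the 320x downsampling factor comes from the standard HuBERT convolutional stack, so treat the exact frame count as approximate):
import mlx.core as mx
from mlx_hubert import load_model
model, config = load_model("mzbac/hubert-large-ls960-ft")
# One second of 16 kHz audio (silence, just to probe shapes).
audio = mx.zeros((1, 16000))
# The feature encoder's conv strides multiply to 320 samples per frame,
# so one second of audio yields roughly 49 frames at the CTC head.
logits = model(audio).logits
print(logits.shape)  # (1, ~49, vocab_size)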
Supported Models
Pre-converted Models on HuggingFace Hub
- mzbac/hubert-large-ls960-ft - Large model fine-tuned for ASR
Converting Your Own Models
Use the included conversion script to convert any HuBERT model:
# Convert base model for feature extraction (automatically detected)
python convert_model.py --model facebook/hubert-base-ls960
# Convert CTC model for speech recognition (automatically detected)
python convert_model.py --model facebook/hubert-large-ls960-ft
The script automatically detects whether a model is a base model or CTC model from its configuration. The converted models will be saved in ./converted_models/ by default.
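For reference, one plausible way such detection can work is to inspect the architectures field of the HuggingFace config. This is an illustration of the idea, not necessarily the script's exact logic:
from transformers import AutoConfig
config = AutoConfig.from_pretrained("facebook/hubert-large-ls960-ft")
# CTC checkpoints list HubertForCTC in their architectures field.
is_ctc = "HubertForCTC" in (config.architectures or [])
print("ctc" if is_ctc else "base")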
Advanced Usage
Batch Processing
# Process multiple audio samples
audio_samples = [ds[i]["audio"]["array"] for i in range(4)]
inputs = processor(audio_samples, padding=True)
input_values = inputs["input_values"]
attention_mask = inputs["attention_mask"]
outputs = model(input_values, attention_mask=attention_mask)
predictions = mx.argmax(outputs.logits, axis=-1)
transcriptions = processor.batch_decode(predictions)
Feature Extraction with Base Models
# Load base model for feature extraction
model, config = load_model("./converted_models/hubert-base-ls960")
processor = HubertProcessor.from_pretrained("./converted_models/hubert-base-ls960")
# Process audio
inputs = processor(audio_array)
input_values = inputs["input_values"]
# Extract features
outputs = model(input_values)
features = outputs.last_hidden_state # Shape: (batch, time, hidden_size)
# Get utterance-level embedding
utterance_embedding = mx.mean(features, axis=1) # Shape: (batch, hidden_size)
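For utterance-similarity tasks, the pooled embeddings can be compared with cosine similarity. A small follow-on sketch (assuming a batch of at least two utterances):
# Cosine similarity between the first two utterance embeddings.
a = utterance_embedding[0]
b = utterance_embedding[1]
similarity = mx.sum(a * b) / (mx.sqrt(mx.sum(a * a)) * mx.sqrt(mx.sum(b * b)))
print(similarity.item())  # 1.0 = identical direction, 0.0 = orthogonal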
Custom Vocabulary
# Define custom vocabulary
vocab_dict = {
"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3,
" ": 4, "A": 5, "B": 6, # ... etc
}
processor = HubertProcessor(
vocab_dict=vocab_dict,
sampling_rate=16000
)
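For illustration only, a full A-Z character vocabulary along these lines could be built programmatically; the token set and ids must match whatever the checkpoint was trained with:
import string
# Hypothetical character vocabulary; ids continue the layout shown above.
vocab_dict = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, " ": 4}
for i, ch in enumerate(string.ascii_uppercase):
    vocab_dict[ch] = i + 5  # A=5, B=6, ..., Z=30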
Model Usage
Using Pre-converted Models
The easiest way is to use models that have already been converted to safetensors format:
from mlx_hubert import load_model, HubertProcessor
# Load from HuggingFace Hub (already converted)
model, config = load_model("mzbac/hubert-large-ls960-ft")
processor = HubertProcessor.from_pretrained("mzbac/hubert-large-ls960-ft")
# Or load from local path
model, config = load_model("./converted_ctc_models")
processor = HubertProcessor.from_pretrained("./converted_ctc_models")
Converting HuggingFace Models
To convert a HuggingFace model to safetensors format for use with MLX:
Using the Command Line
# Convert a CTC model (auto-detects model type)
python convert_model.py --model facebook/hubert-large-ls960-ft
# Convert a base model
python convert_model.py --model facebook/hubert-base-ls960 --type base
# Convert to a specific directory
python convert_model.py --model facebook/hubert-large-ls960-ft --output ./my_model
# Convert without testing
python convert_model.py --model facebook/hubert-large-ls960-ft --no-test
Using the Python API
from mlx_hubert import convert_from_transformers
# Convert a model programmatically
model_path, config_path = convert_from_transformers(
"facebook/hubert-large-ls960-ft",
"./converted_model",
model_type="auto" # or "ctc", "base"
)
# Then load the converted model
from mlx_hubert import load_model, HubertProcessor
model, config = load_model("./converted_model")
processor = HubertProcessor.from_pretrained("./converted_model")
Direct PyTorch to MLX Conversion (Advanced)
For advanced users who want to convert models programmatically:
from transformers import HubertForCTC as HFHubertForCTC
from mlx_hubert import HubertForCTC, HubertConfig
from mlx_hubert.utils import load_pytorch_weights
# Load HuggingFace model
hf_model = HFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
# Create MLX config from HuggingFace config
config = HubertConfig.from_dict(hf_model.config.to_dict())
# Initialize MLX model
mlx_model = HubertForCTC(config)
# Load weights from PyTorch state dict
mlx_model = load_pytorch_weights(mlx_model, hf_model.state_dict(), config)
# Now you can use the model
mlx_model.eval()
Examples
Check the examples/ directory for:
- simple_transcription.py - Basic speech recognition
- speech_recognition.py - Advanced examples with batching and streaming
- feature_extraction.py - Extract speech representations
- base_model_example.py - Using base models for feature extraction and similarity
Development
Running Tests
pip install -e ".[dev]"
pytest tests/
Code Style
black mlx_hubert/
isort mlx_hubert/
flake8 mlx_hubert/
Citation
Original HuBERT paper:
@article{hsu2021hubert,
  title={HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units},
  author={Hsu, Wei-Ning and Bolte, Benjamin and Tsai, Yao-Hung Hubert and Lakhotia, Kushal and Salakhutdinov, Ruslan and Mohamed, Abdelrahman},
  journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
  year={2021}
}
License
This project is licensed under the MIT License - see the LICENSE file for details.
Acknowledgments
- Thanks to the MLX team at Apple for the excellent framework
- The HuggingFace team for the Transformers implementation
- Meta AI Research for the original HuBERT model
File details
Details for the file mlx_hubert-0.1.0.tar.gz.
File metadata
- Download URL: mlx_hubert-0.1.0.tar.gz
- Upload date:
- Size: 32.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 477fea2402392d7bf07e4f6ac143e94194f05e291f9f4479066af2b0736b86de |
| MD5 | ce4bc64f418622ff79bb2bba12d3960f |
| BLAKE2b-256 | 18018b01cf02f4a1957adebabfc77aa83af082bbf86a38f81b3c2aa6460fbfee |
File details
Details for the file mlx_hubert-0.1.0-py3-none-any.whl.
File metadata
- Download URL: mlx_hubert-0.1.0-py3-none-any.whl
- Upload date:
- Size: 22.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ee59b6bb0cfd1b78f7a44f1f7a30460e53eb0a94c082bde0587847158ec3e182 |
| MD5 | 1169d9d5b6770a70f9d20d4fa6368d5f |
| BLAKE2b-256 | 7cce7363b969e1b72d27cb55ac5cfed93315a5002dbb90d286c1b06cbb8dcbad |