
TransFusion: Transcribing Speech with Multinomial Diffusion

The official code repo! This repo contains code for training, inference, and scoring of TransFusion ASR models from our paper, "TransFusion: Transcribing Speech with Multinomial Diffusion". The trained checkpoints are available under the 'Releases' tab, although the quickstart below will download them for you. Hope you find this useful!


Figure: the TransFusion architecture diagram showing both training and inference, as given in the paper.

Authors: Matthew Baas, Kevin Eloff, Herman Kamper


Quickstart

We use torch hub to make model loading very easy -- no cloning of the repo needed! The steps to perform ASR inference with a trained checkpoint are simple:

  1. Install pip dependencies: ensure torch, torchaudio, numpy, omegaconf, fairseq, fastprogress, jiwer, and pandas are installed (for the full training dependencies, see requirements.txt). Make sure you are using Python 3.10 or above, since this repo uses some of the newer Python 3.10 features. For example (package names as listed above; exact installation steps for torch and fairseq may vary with your platform and CUDA version):
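pip install torch torchaudio numpy omegaconf fairseq fastprogress jiwer pandas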
  2. Load models: load the trained TransFusion model and frozen WavLM encoder:
import torch
import torchaudio

device = 'cpu' # or 'cuda' if you have enough GPU memory.
wavlm = torch.hub.load('RF5/transfusion-asr', 'wavlm_large', device=device)
transfusion = torch.hub.load('RF5/transfusion-asr', 'transfusion_small_462k', device=device)
  3. Compute WavLM features: load a 16 kHz waveform and compute the WavLM features:
path = '<path to arbitrary 16kHz waveform>.wav'
x, sr = torchaudio.load(path)
assert sr == 16000
# get weighted WavLM features:
features = wavlm.extract_transfusion_features(x.to(device), wavlm) # (seq_len, dim)
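# note (an addition, not from the original README): if your source audio is
# not 16 kHz, resample it before extraction with the standard torchaudio call:
#   x = torchaudio.functional.resample(x, sr, 16000)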
  4. Predict transcript: perform multinomial diffusion using all the additional techniques from the paper:
pred_inds, pred_text = transfusion.perform_simple_inference(
    transfusion, # pass in model to use in diffusion
    features[None],  # add batch dimension to features
    transfusion.diffuser, # diffuser containing diffusion parameters
    transfusion.vocab, # vocab for converting indices to text / text to indices
    transfusion.cfg # model/diffusion config dict
)
print(pred_text)
# prints out the predicted transcript of your utterance!

That's it, trivial! You can modify the diffusion parameters using the DSH class in transfusion/score.py and in the diffuser config. By default it uses the optimal settings found in the paper.
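For instance, here is a hedged sketch of overriding a couple of decoding hyperparameters before inference. The attribute names below are illustrative assumptions, not taken from the repo -- check transfusion/score.py for the actual fields of DSH:

from transfusion.score import DSH

# hypothetical hyperparameter overrides -- replace with the real DSH fields:
DSH.x_0_temp = 0.8     # assumed name: temperature applied to x_0 predictions
DSH.guidance_w = 1.5   # assumed name: classifier-free guidance weight

# then run inference exactly as in the quickstart above
pred_inds, pred_text = transfusion.perform_simple_inference(
    transfusion, features[None], transfusion.diffuser,
    transfusion.vocab, transfusion.cfg
)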

Checkpoints

Under the releases tab of this repo we provide two checkpoints:

  • The frozen WavLM encoder taken from the original WavLM authors, which we host here for convenience and torch hub integration.
  • The best TransFusion model presented in the paper, i.e. the model trained for 462k updates.

The performance on the LibriSpeech test sets is summarized below:

checkpoint               Params (M)   LS test-clean WER (%)   LS test-other WER (%)
transfusion_small_462k   253          6.7                     8.8
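Since jiwer is one of the listed dependencies, here is a minimal sketch of how one could compute a WER like the ones above for a predicted transcript. The strings are made up for illustration; the full evaluation pipeline lives in transfusion/score.py:

import jiwer

reference = "the cat sat on the mat"   # ground-truth transcript (illustrative)
hypothesis = "the cat sat on a mat"    # e.g. pred_text from the quickstart
wer = jiwer.wer(reference, hypothesis)
print(f"WER: {100 * wer:.1f}%")        # word error rate as a percentage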

Training

For training you must also install deepspeed.

Preparing data

Before training, one needs to prepare the data. The steps to do this for the LibriSpeech dataset are:

  1. First download and extract the LibriSpeech dataset.

  2. Then extract the WavLM features with the extract.py script:

usage: python -m wavlm.extract [--librispeech_path PATH/TO/LIBRISPEECH] [--ckpt_path PATH/TO/WAVLM_LARGE_CKPT] [--out_path PATH/TO/FEAT]

required arguments:
    --librispeech_path          root path of librispeech dataset
    --out_path                  target directory to save WavLM features into
    --ckpt_path                 path to pretrained WavLM checkpoint

optional arguments:
    --seed                      random seed
    --device                    device to run feature extraction on (e.g. cpu or cuda)
  3. Split the data into train, validation, and test splits using the split_data.py script:
usage: split_data.py --librispeech_path LIBRISPEECH_PATH --ls_wavlm_path LS_WAVLM_PATH [--include_test]

Generate train & valid csvs from dataset directories

options:
  --librispeech_path LIBRISPEECH_PATH
                        path to root of librispeech dataset
  --ls_wavlm_path LS_WAVLM_PATH
                        path to root of WavLM features extracted using extract.py
  --include_test        include processing and saving test.csv for test subsets

Running this will save the train/valid/test csv files and a vocabulary dict as vocab.pt into a ./splits/ folder.
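As a quick sanity check after splitting (assuming vocab.pt is a torch-serialized vocabulary object and the csvs contain one row per utterance -- adjust to the actual file formats):

import torch
import pandas as pd

vocab = torch.load('splits/vocab.pt')    # vocabulary dict saved by split_data.py
train = pd.read_csv('splits/train.csv')  # assumed: one row per training utterance
print(f"loaded vocab of type {type(vocab).__name__} and {len(train)} training rows")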

Now you are ready to get training!

Training

The training, model, and distributed-computing configs are specified in transfusion/config.py, deepspeed_cfg.json, and train.py. To train the model according to the paper's specification, launch train.py with the following deepspeed command:

deepspeed --num_nodes 1 train.py train_csv=splits/train.csv valid_csv=splits/valid.csv \
    checkpoint_path=runs/pog-debug/ vocab_path=splits/vocab.pt batch_size=12 \
    --deepspeed --deepspeed_config=deepspeed_cfg.json \
    validation_interval=20000 checkpoint_interval=20000

That's it! Now both logs and checkpoints will be saved into the checkpoint_path and the output_path specified in deepspeed_cfg.json.

You can get a detailed score for a trained checkpoint using the transfusion/score.py script (see its help message for usage); this is the script used to perform the final LibriSpeech evaluations. It contains all the special decoding strategies introduced in the paper, as well as the main decoding hyperparameters.
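For example, mirroring the module-style invocation used for wavlm.extract above (an assumption about the entry point -- check the repo if this fails):

python -m transfusion.score --help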

Repository structure:

The repository is organized as follows:

├── transfusion
│   ├── config.py                   # hyperparameters
│   ├── dataset.py                  # data loading and processing
│   ├── diffusion.py                # diffusion helper functions
│   ├── eval.py                     # logging and evaluation metrics
│   ├── model.py                    # model definition
│   ├── score.py                    # evaluation function
│   ├── utils.py                    # training helper functions
│   └── wavlm_modules.py            # wavlm model modules (from original WavLM repo)
├── wavlm
│   ├── extract.py                  # wavlm feature extraction script
│   ├── modules.py                  # wavlm helper functions (from original WavLM repo)
│   └── WavLM.py                    # wavlm modules (from original WavLM repo)
├── deepspeed_cfg.json              # deepspeed config
├── hubconf.py                      # torchhub integration
├── README.md
├── requirements.txt
├── split_data.py                   # splits data into train/valid/test subsets
├── train.py                        # main training script
└── TransFusion.png                 # TransFusion architecture diagram

Acknowledgements

Parts of the code for this project are adapted from other repositories (most notably the original WavLM repo, from which the WavLM modules are taken) -- please make sure to check them out! Thank you to their authors.

Citation

@inproceedings{baas2022transfusion,
  title={TransFusion: Transcribing Speech with Multinomial Diffusion},
  author={Baas, Matthew and Eloff, Kevin and Kamper, Herman},
  booktitle={SACAIR},
  year=2022
}
