wavegrad

Project description

WaveGrad

WaveGrad is a fast, high-quality neural vocoder designed by the folks at Google Brain. The architecture is described in WaveGrad: Estimating Gradients for Waveform Generation. In short, this model takes a log-scaled Mel spectrogram and converts it to a waveform via iterative refinement.
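
To make the iterative-refinement idea concrete, here is a minimal sketch of a DDPM-style sampling loop in the spirit of the paper. This is not the package's actual sampler (use wavegrad.inference for that), and the model signature and hop_samples value are illustrative assumptions:

import torch

def refine(model, spectrogram, noise_schedule, hop_samples=300):
    # `noise_schedule` is a sequence of per-step noise levels (betas).
    beta = torch.as_tensor(noise_schedule, dtype=torch.float32)
    alpha = 1.0 - beta
    alpha_bar = torch.cumprod(alpha, dim=0)

    # Start from pure Gaussian noise, one audio sample per spectrogram
    # frame times the hop size (hop_samples=300 is an assumption).
    audio = torch.randn(spectrogram.shape[0], spectrogram.shape[-1] * hop_samples)

    # Walk the schedule backwards, removing a little noise at each step.
    for n in reversed(range(len(beta))):
        # `model` is assumed to predict the noise component from the noisy
        # audio, the conditioning spectrogram, and the current noise level.
        eps = model(audio, spectrogram, torch.sqrt(alpha_bar[n]))
        audio = (audio - (1 - alpha[n]) / torch.sqrt(1 - alpha_bar[n]) * eps)
        audio = audio / torch.sqrt(alpha[n])
        if n > 0:  # add fresh noise at every step except the last
            audio = audio + torch.sqrt(beta[n]) * torch.randn_like(audio)
    return audio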

Status (2020-10-15)

  • stable training (22 kHz, 24 kHz)
  • high-quality synthesis
  • mixed-precision training
  • multi-GPU training
  • custom noise schedule (faster inference)
  • command-line inference
  • programmatic inference API
  • PyPI package
  • audio samples
  • pretrained models
  • precomputed noise schedule

Audio samples

24 kHz audio samples

Pretrained models

24 kHz pretrained model (183 MB, SHA256: 65e9366da318d58d60d2c78416559351ad16971de906e53b415836c068e335f3)
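
To confirm the download is intact, you can check it against the published SHA256 with Python's standard library (the local filename below is a placeholder for wherever you saved the checkpoint):

import hashlib

EXPECTED = '65e9366da318d58d60d2c78416559351ad16971de906e53b415836c068e335f3'

sha256 = hashlib.sha256()
with open('wavegrad-24kHz.pt', 'rb') as f:  # placeholder filename
    # Hash in 1 MiB chunks so the checkpoint never has to fit in memory at once.
    for chunk in iter(lambda: f.read(1 << 20), b''):
        sha256.update(chunk)

assert sha256.hexdigest() == EXPECTED, 'checksum mismatch -- re-download'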

Install

Install using pip:

pip install wavegrad

or from GitHub:

git clone https://github.com/lmnt-com/wavegrad.git
cd wavegrad
pip install .

Training

Before you start training, you'll need to prepare a training dataset. The dataset can have any directory structure as long as the contained .wav files are 16-bit mono (e.g. LJSpeech, VCTK). By default, this implementation assumes a sample rate of 22 kHz. If you need to change this value, edit params.py.

python -m wavegrad.preprocess /path/to/dir/containing/wavs
python -m wavegrad /path/to/model/dir /path/to/dir/containing/wavs

# in another shell to monitor training progress:
tensorboard --logdir /path/to/model/dir --bind_all
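
If you're not sure whether your dataset meets these requirements, a quick pass with Python's standard-library wave module can flag offending files before preprocessing (a minimal sketch; adjust EXPECTED_RATE to whatever params.py uses):

import sys
import wave
from pathlib import Path

EXPECTED_RATE = 22050  # match the sample rate configured in params.py

# Walk the dataset directory and report any .wav file that is not
# 16-bit (sample width 2 bytes), mono, at the expected sample rate.
for path in Path(sys.argv[1]).rglob('*.wav'):
    with wave.open(str(path), 'rb') as f:
        ok = (f.getsampwidth() == 2
              and f.getnchannels() == 1
              and f.getframerate() == EXPECTED_RATE)
    if not ok:
        print(f'needs conversion: {path}')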

You should expect to hear intelligible speech by ~20k steps (~1.5h on a 2080 Ti).

Inference API

Basic usage:

from wavegrad.inference import predict as wavegrad_predict

model_dir = '/path/to/model/dir'
spectrogram = ...  # get your hands on a spectrogram in [N,C,W] format
audio, sample_rate = wavegrad_predict(spectrogram, model_dir)

# `audio` is a GPU tensor in [N,T] format.
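
A slightly fuller sketch that loads a spectrogram and writes the result to disk. The .spec.npy extension is an assumption about what the preprocessing step writes; check your preprocessed directory for the actual naming convention:

import numpy as np
import torch
import torchaudio

from wavegrad.inference import predict as wavegrad_predict

# Load a spectrogram written by the preprocessing step (extension assumed).
spectrogram = torch.from_numpy(np.load('/path/to/example.spec.npy')).float()
if spectrogram.dim() == 2:
    spectrogram = spectrogram.unsqueeze(0)  # [C,W] -> [N,C,W]

audio, sample_rate = wavegrad_predict(spectrogram, '/path/to/model/dir')

# `audio` comes back on the GPU in [N,T] format; move to CPU before saving.
torchaudio.save('output.wav', audio.cpu(), sample_rate)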

If you have a custom noise schedule (see below):

import numpy as np

from wavegrad.inference import predict as wavegrad_predict

params = { 'noise_schedule': np.load('/path/to/noise_schedule.npy') }
model_dir = '/path/to/model/dir'
spectrogram = ...  # get your hands on a spectrogram in [N,C,W] format
audio, sample_rate = wavegrad_predict(spectrogram, model_dir, params=params)

# `audio` is a GPU tensor in [N,T] format.

Inference CLI

python -m wavegrad.inference /path/to/model /path/to/spectrogram -o output.wav

Noise schedule

The default implementation uses 1000 iterations to refine the waveform, which runs slower than real-time. WaveGrad can achieve high-quality, faster-than-real-time synthesis with as few as 6 iterations, without retraining the model with new hyperparameters.

To achieve this speed-up, you will need to search for a noise schedule that works well for your dataset. This implementation provides a script to perform the search for you:

python -m wavegrad.noise_schedule /path/to/trained/model /path/to/preprocessed/validation/dataset
python -m wavegrad.inference /path/to/trained/model /path/to/spectrogram -n noise_schedule.npy -o output.wav

The default settings should give good results without spending too much time on the search. If you'd like to find a better noise schedule or use a different number of inference iterations, run the noise_schedule script with --help to see additional configuration options.
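
The searched schedule is just a NumPy array of per-step noise levels saved with np.save, so you can also experiment by hand. A minimal sketch (the six values below are illustrative placeholders, not a tuned schedule):

import numpy as np

# Six per-step noise levels instead of the default 1000-step schedule.
# These values are placeholders; use wavegrad.noise_schedule to search
# for values that actually work well on your dataset.
noise_schedule = np.array([1e-4, 1e-3, 1e-2, 5e-2, 2e-1, 5e-1])
np.save('noise_schedule.npy', noise_schedule)

Pass the resulting file with -n on the command line, or through the params dict shown in the Inference API section.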

References

  • Nanxin Chen, Yu Zhang, Heiga Zen, Ron J. Weiss, Mohammad Norouzi, William Chan. WaveGrad: Estimating Gradients for Waveform Generation. arXiv:2009.00713, 2020.


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

wavegrad-0.1.4.tar.gz (12.8 kB)


File details

Details for the file wavegrad-0.1.4.tar.gz.

File metadata

  • Download URL: wavegrad-0.1.4.tar.gz
  • Upload date:
  • Size: 12.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.1.post20200802 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.8.5

File hashes

Hashes for wavegrad-0.1.4.tar.gz

  • SHA256: 5665442773e8ffe4669e32ef1ead7294f0b09aa0b776549cc2fc9694eab9731a
  • MD5: df9bbcab01495b4a1d0566b5848c03a2
  • BLAKE2b-256: 4ff13ea9437b53346fcdc446428ac9b16614ac51f089d2b5bfe708f77e4cf4db

