
E2 TTS — MLX

Implementation of E2 TTS, "Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS", with the MLX framework.

This implementation is based on the lucidrains implementation in PyTorch, which differs from the paper in that it uses a multistream transformer for text and audio, with conditioning applied at every transformer block.
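
As a rough sketch of what "conditioning at every block" means, each audio-stream block can fold in the text stream's hidden state at the same depth before its attention and feed-forward steps. The module below is illustrative only, with hypothetical names, and is not one of the classes this package actually exposes:

import mlx.core as mx
import mlx.nn as nn

class ConditionedBlock(nn.Module):
    # One audio-stream transformer block conditioned on the text stream
    # at the same depth (illustrative sketch, not the package's real layers).
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiHeadAttention(dim, heads)
        self.ff_norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.to_cond = nn.Linear(dim, dim)  # projects the text hidden state into the audio stream

    def __call__(self, audio_h: mx.array, text_h: mx.array) -> mx.array:
        x = audio_h + self.to_cond(text_h)   # condition on the text stream at this depth
        h = self.attn_norm(x)
        x = x + self.attn(h, h, h)           # self-attention over the audio stream
        return x + self.ff(self.ff_norm(x))  # feed-forward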

Installation

pip install mlx-e2-tts

Usage

import mlx.core as mx

from e2_tts_mlx.model import E2TTS
from e2_tts_mlx.trainer import E2Trainer
from e2_tts_mlx.data import load_libritts_r

e2tts = E2TTS(
    tokenizer = "char-utf8",  # or "phoneme_en"
    cond_drop_prob = 0.25,
    frac_lengths_mask = (0.7, 0.9),
    transformer = dict(
        dim = 1024,
        depth = 24,
        heads = 16,
        text_depth = 12,
        text_heads = 8,
        text_ff_mult = 4,
        max_seq_len = 4096,
        dropout = 0.1
    )
)
mx.eval(e2tts.parameters())

batch_size = 128
max_duration = 30

dataset = load_libritts_r(split="dev-clean")  # or any audio/caption dataset

trainer = E2Trainer(model = e2tts, num_warmup_steps = 1000)

trainer.train(
    train_dataset = dataset,
    learning_rate = 7.5e-5,
    batch_size = batch_size
)

# ... after much training ...

cond = ...  # reference audio (or mel spectrogram) to condition on
text = ...  # text to synthesize
duration = ...  # from a trained DurationPredictor or otherwise

generated_mel_spec = e2tts.sample(
    cond = cond,
    text = text,
    duration = duration,
    steps = 32,
    cfg_strength = 1.0,  # if trained for cfg
)
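
If no duration model is at hand, one crude option is to scale the reference clip's length by the relative amount of text. This is purely a heuristic sketch; ref_text and the frame-first shape of cond are assumptions, not part of this package's API:

# Hypothetical duration estimate: keep the reference frames and add frames
# proportional to how much new text there is relative to the reference transcript.
ref_frames = cond.shape[0]  # assumes cond is laid out as (frames, n_mels)
ref_text = "..."            # transcript of the reference clip (hypothetical placeholder)
duration = int(ref_frames * (1 + len(text) / max(len(ref_text), 1)))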

Note that the model configuration specified above (taken from the paper) is very large. See train_example.py for a more practically sized model you can train on your local device.
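
For example, a much smaller configuration along the same lines could look like the following (illustrative values only, not necessarily those used in train_example.py):

e2tts_small = E2TTS(
    tokenizer = "char-utf8",
    cond_drop_prob = 0.25,
    transformer = dict(
        dim = 384,
        depth = 8,
        heads = 6,
        text_depth = 4,
        text_heads = 4,
        text_ff_mult = 2,
        max_seq_len = 1024,
        dropout = 0.1
    )
)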

Appreciation

lucidrains for the original implementation in PyTorch.

Citations

@inproceedings{Eskimez2024E2TE,
    title   = {E2 TTS: Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS},
    author  = {Sefik Emre Eskimez and Xiaofei Wang and Manthan Thakker and Canrun Li and Chung-Hsien Tsai and Zhen Xiao and Hemin Yang and Zirun Zhu and Min Tang and Xu Tan and Yanqing Liu and Sheng Zhao and Naoyuki Kanda},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:270738197}
}
@article{Burtsev2021MultiStreamT,
    title     = {Multi-Stream Transformers},
    author    = {Mikhail S. Burtsev and Anna Rumshisky},
    journal   = {ArXiv},
    year      = {2021},
    volume    = {abs/2107.10342},
    url       = {https://api.semanticscholar.org/CorpusID:236171087}
}

License

The code in this repository is released under the MIT license as found in the LICENSE file.

