Skip to main content

MMDS: A general-purpose multimodal dataset wrapper.

Project description

MMDS: A general-purpose multimodal dataset wrapper

This project is under construction; the API may change from time to time.

Installation

Stable (not actually stable yet)

pip install mmds

Latest (pre-release)

pip install mmds --pre

Example Usage

# example.py

import timeit
from pathlib import Path
from multiprocessing import Manager

from mmds import MultimodalDataset, MultimodalSample
from mmds.exceptions import PackageNotFoundError
from mmds.modalities.rgbs import RgbsModality
from mmds.modalities.wav import WavModality
from mmds.modalities.mel import MelModality
from mmds.modalities.f0 import F0Modality
from mmds.modalities.ge2e import Ge2eModality
from mmds.utils.spectrogram import LogMelSpectrogram


try:
    import youtube_dl
    import ffmpeg
    import torch
    from torchvision import transforms
except ImportError:
    raise PackageNotFoundError(
        "youtube_dl",
        "ffmpeg-python",
        "torch",
        "torchvision",
        by="example.py",
    )


def download(video_id="BaW_jenozKc", root="data", fps=25):
    """Download the example video and prepare the files the dataset reads.

    Fetches ``https://www.youtube.com/watch?v=<video_id>`` with youtube-dl into
    ``root``. The download also extracts an mp3 audio track resampled to 16 kHz
    and keeps the video itself. The video's frames are then dumped at ``fps``
    frames per second as zero-padded, zero-based PNGs into ``<root>/<video_id>/``.

    Args:
        video_id: YouTube video id; youtube-dl names the output files after it.
        root: Directory that all downloaded and extracted files are written to.
        fps: Frame rate used when extracting PNG frames from the video.
    """
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)

    ydl_opts = {
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "mp3",
                "preferredquality": "192",
            }
        ],
        # resample the extracted audio to 16 kHz
        "postprocessor_args": ["-ar", "16000"],
        "outtmpl": str(root / "%(id)s.%(ext)s"),
        # keep the video after audio extraction; frames are extracted from it below
        "keepvideo": True,
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(
            "https://www.youtube.com/watch?v=" + video_id, download=True
        )
        # youtube-dl picks the container (mp4, webm, mkv, ...) from the available
        # formats, so ask it for the real file name instead of assuming ".mp4".
        video_path = ydl.prepare_filename(info)

    frames_dir = root / video_id
    frames_dir.mkdir(exist_ok=True)

    # Extract whenever no frames are present, so a previous run that created the
    # directory but failed during extraction is retried instead of skipped.
    if not any(frames_dir.glob("*.png")):
        (
            ffmpeg.input(video_path)
            .filter("fps", fps=str(fps))
            .output(str(frames_dir / "%06d.png"), start_number=0)
            .overwrite_output()
            .run(quiet=True)
        )


class MyMultimodalSample(MultimodalSample):
    """Sample type whose ``info`` is derived from its ``wav`` and ``rgbs`` modalities."""

    def generate_info(self):
        """Return the info dict for this sample.

        Returns:
            A dict with ``t0`` fixed at 0, ``t1`` set to one tenth of the ``wav``
            modality's duration (presumably a time window; confirm against the
            library), and the raw durations of the ``wav`` and ``rgbs`` modalities.
        """
        audio = self.get_modality_by_name("wav")
        frames = self.get_modality_by_name("rgbs")
        info = {
            "t0": 0,
            "t1": audio.duration / 10,
            "original_wav_seconds": audio.duration,
            "original_rgbs_seconds": frames.duration,
        }
        return info


class MyMultimodalDataset(MultimodalDataset):
    """MultimodalDataset whose ``Sample`` hook is set to :class:`MyMultimodalSample`."""

    # class attribute read by MultimodalDataset — presumably the type used to
    # build each sample, so MyMultimodalSample.generate_info supplies "info"
    Sample = MyMultimodalSample


def main():
    """Run the example: download data, build the dataset, and inspect one sample.

    Prints the time taken by a first and a second access to ``dataset[0]``.
    The second access presumably hits the per-modality caches. It then prints
    the sample's ``"info"`` entry, followed by the key, shape and type of every
    value in the sample that has a ``shape`` attribute.
    """
    download()

    # optional multiprocessing cache manager; using it as a context manager
    # shuts its server process down once the example is done
    with Manager() as manager:
        dataset = MyMultimodalDataset(
            ["BaW_jenozKc"],
            modality_factories=[
                RgbsModality.create_factory(
                    name="rgbs",
                    root="data",
                    suffix="*.png",
                    sample_rate=25,
                    transform=transforms.Compose(
                        [
                            transforms.Resize((28, 28)),
                            transforms.ToTensor(),
                            transforms.Normalize(0.5, 1),
                        ],
                    ),
                    aggragate=torch.stack,
                    cache=manager.dict(),
                ),
                WavModality.create_factory(
                    name="wav",
                    root="data",
                    suffix=".mp3",
                    sample_rate=16_000,
                    cache=manager.dict(),
                ),
                MelModality.create_factory(
                    name="mel",
                    root="data",
                    suffix=".mel.npz",
                    mel_fn=LogMelSpectrogram(sample_rate=16_000),
                    base_modality_name="wav",
                    cache=manager.dict(),
                ),
                F0Modality.create_factory(
                    name="f0",
                    root="data",
                    suffix=".f0.npz",
                    mel_fn=LogMelSpectrogram(sample_rate=16_000),
                    base_modality_name="wav",
                    cache=manager.dict(),
                ),
                Ge2eModality.create_factory(
                    name="ge2e",
                    root="data",
                    suffix=".ge2e.npz",
                    sample_rate=16_000,
                    base_modality_name="wav",
                    cache=manager.dict(),
                    fetching=False,
                ),
            ],
        )

        # first load
        print(timeit.timeit(lambda: dataset[0], number=1))

        # second load
        print(timeit.timeit(lambda: dataset[0], number=1))

        sample = dataset[0]
        print(sample["info"])

        for key, value in sample.items():
            # Only array-like values have a shape; skip the others explicitly
            # rather than silencing every possible exception with a bare except.
            try:
                shape = value.shape
            except AttributeError:
                continue
            print(key, shape, type(value))


# Run the example only when executed as a script. The guard is also required
# by multiprocessing's "spawn" start method, because main() creates a Manager.
if __name__ == "__main__":
    main()

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mmds-0.0.1.dev20211222145341.tar.gz (12.6 kB view details)

Uploaded Source

Built Distribution

mmds-0.0.1.dev20211222145341-py3-none-any.whl (16.0 kB view details)

Uploaded Python 3

File details

Details for the file mmds-0.0.1.dev20211222145341.tar.gz.

File metadata

  • Download URL: mmds-0.0.1.dev20211222145341.tar.gz
  • Upload date:
  • Size: 12.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1

File hashes

Hashes for mmds-0.0.1.dev20211222145341.tar.gz
Algorithm Hash digest
SHA256 aed44ec0872de9cc678176b9fa62bf939f7f5215ce6759abbd85c69abf858b2f
MD5 f8e4ceb3c99b89bc4be258bbe3a2afb1
BLAKE2b-256 3a0519c9099b7be03bf98558a3997b2c094aa0bc92484906356262e9ece78a26

See more details on using hashes here.

File details

Details for the file mmds-0.0.1.dev20211222145341-py3-none-any.whl.

File metadata

  • Download URL: mmds-0.0.1.dev20211222145341-py3-none-any.whl
  • Upload date:
  • Size: 16.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1

File hashes

Hashes for mmds-0.0.1.dev20211222145341-py3-none-any.whl
Algorithm Hash digest
SHA256 b3d7807f000106d6d35f47824eba08006d86ff509278ef6476e562b7311f3adf
MD5 04311cc24c63e5da7d6ecc60d5966a6d
BLAKE2b-256 6bb18e10e14e0390a7e0404378b102ea8ec246b65e1198a9721197c8751c623c

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page