MMDS: A general-purpose multimodal dataset wrapper.

This project is under construction; the API may change from time to time.

Installation

Stable (not stable yet though)

pip install mmds

Latest

pip install mmds --pre
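
After either command, it can be useful to confirm which version ended up in the environment, since the API is still changing. The snippet below is a minimal sketch that uses only the standard library; `installed_version` is a hypothetical helper name, not part of mmds.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# "mmds" is the distribution name used in the pip commands above.
print(installed_version("mmds"))
```

The helper returns `None` instead of raising, so it can be called safely before deciding whether to install.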

Example Usage

# example.py

import timeit
from mmds import MultimodalDataset, MultimodalSample
from mmds.modalities import (
    RgbsModality,
    WavModality,
    MelModality,
    F0Modality,
    Ge2eModality,
)
from mmds.utils.spectrogram import LogMelSpectrogram
from pathlib import Path
from multiprocessing import Manager


try:
    import youtube_dl
    import ffmpeg
    import torch
    from torchvision import transforms
except ImportError as e:
    raise ImportError(
        "This demo requires youtube_dl, ffmpeg-python, torch and torchvision. "
        "Install them with: pip install youtube_dl ffmpeg-python torch torchvision"
    ) from e


def download():
    Path("data").mkdir(exist_ok=True)

    ydl_opts = {
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "mp3",
                "preferredquality": "192",
            }
        ],
        "postprocessor_args": ["-ar", "16000"],
        "outtmpl": "data/%(id)s.%(ext)s",
        "keepvideo": True,
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download(["https://www.youtube.com/watch?v=BaW_jenozKc"])

    path = Path("data/BaW_jenozKc")

    if not path.exists():
        path.mkdir(exist_ok=True)

        (
            ffmpeg.input("data/BaW_jenozKc.mp4")
            .filter("fps", fps="25")
            .output("data/BaW_jenozKc/%06d.png", start_number=0)
            .overwrite_output()
            .run(quiet=True)
        )


class MyMultimodalSample(MultimodalSample):
    def generate_info(self):
        wav_modality = self.get_modality_by_name("wav")
        rgbs_modality = self.get_modality_by_name("rgbs")
        return dict(
            t0=0,
            t1=wav_modality.duration / 10,
            original_wav_seconds=wav_modality.duration,
            original_rgbs_seconds=rgbs_modality.duration,
        )


class MyMultimodalDataset(MultimodalDataset):
    Sample = MyMultimodalSample


def main():
    download()

    # optional multiprocessing cache manager
    manager = Manager()

    dataset = MyMultimodalDataset(
        ["BaW_jenozKc"],
        modality_factories=[
            RgbsModality.create_factory(
                name="rgbs",
                root="data",
                suffix="*.png",
                sample_rate=25,
                transform=transforms.Compose(
                    [
                        transforms.Resize((28, 28)),
                        transforms.ToTensor(),
                        transforms.Normalize(0.5, 1),
                    ],
                ),
                aggragate=torch.stack,
                cache=manager.dict(),
            ),
            WavModality.create_factory(
                name="wav",
                root="data",
                suffix=".mp3",
                sample_rate=16_000,
                cache=manager.dict(),
            ),
            MelModality.create_factory(
                name="mel",
                root="data",
                suffix=".mel.npz",
                mel_fn=LogMelSpectrogram(sample_rate=16_000),
                base_modality_name="wav",
                cache=manager.dict(),
            ),
            F0Modality.create_factory(
                name="f0",
                root="data",
                suffix=".f0.npz",
                mel_fn=LogMelSpectrogram(sample_rate=16_000),
                base_modality_name="wav",
                cache=manager.dict(),
            ),
            Ge2eModality.create_factory(
                name="ge2e",
                root="data",
                suffix=".ge2e.npz",
                sample_rate=16_000,
                base_modality_name="wav",
                cache=manager.dict(),
            ),
        ],
    )

    # first load
    print(timeit.timeit(lambda: dataset[0], number=1))

    # second load
    print(timeit.timeit(lambda: dataset[0], number=1))

    print(dataset[0]["info"])

    for key, value in dataset[0].items():
        try:
            print(key, value.shape, type(value))
        except AttributeError:
            # skip values that have no .shape attribute
            pass


if __name__ == "__main__":
    main()
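
The two timed `dataset[0]` calls in the example suggest that the second load is served from the per-modality `cache` mapping. How mmds uses that mapping internally is not shown here, so the following is only a sketch of the general pattern under that assumption: a loader is called once per key and later requests are answered from a dict-like cache. `cached_load` and `slow_loader` are hypothetical names for illustration, and a plain `dict` stands in for `manager.dict()`.

```python
import time
from typing import Any, Callable, List, MutableMapping

calls: List[str] = []  # records every real (uncached) load


def slow_loader(key: str) -> str:
    """Hypothetical stand-in for decoding a file from disk."""
    calls.append(key)
    time.sleep(0.05)
    return key.upper()


def cached_load(
    cache: MutableMapping[str, Any], key: str, loader: Callable[[str], Any]
) -> Any:
    """Return cache[key], calling loader(key) only on the first request."""
    if key not in cache:
        cache[key] = loader(key)
    return cache[key]


# Any mutable mapping works; a multiprocessing Manager().dict() could be
# passed instead to share entries between worker processes.
cache: dict = {}

start = time.perf_counter()
first = cached_load(cache, "BaW_jenozKc", slow_loader)
first_elapsed = time.perf_counter() - start

start = time.perf_counter()
second = cached_load(cache, "BaW_jenozKc", slow_loader)
second_elapsed = time.perf_counter() - start

print(f"first load:  {first_elapsed:.4f}s")
print(f"second load: {second_elapsed:.4f}s")
```

Passing the mapping in as an argument keeps the loader independent of where the cache lives, which is why a process-shared dict can replace the local one without other changes.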
