Skip to main content

MMDS: A general-purpose multimodal dataset wrapper.

Project description

MMDS: A general-purpose multimodal dataset wrapper

This project is under construction, API may change from time to time.

Installation

Stable (not stable yet though)

pip install mmds

Latest

pip install mmds --pre

Example Usage

# example.py

import timeit
from pathlib import Path
from multiprocessing import Manager

from mmds import MultimodalDataset, MultimodalSample
from mmds.exceptions import PackageNotFoundError
from mmds.modalities.rgbs import RgbsModality
from mmds.modalities.wav import WavModality
from mmds.modalities.mel import MelModality
from mmds.modalities.f0 import F0Modality
from mmds.modalities.ge2e import Ge2eModality
from mmds.utils.spectrogram import LogMelSpectrogram


try:
    import youtube_dl
    import ffmpeg
    import torch
    from torchvision import transforms
except ImportError:
    raise PackageNotFoundError(
        "youtube_dl",
        "ffmpeg-python",
        "torch",
        "torchvision",
        by="example.py",
    )


def download():
    Path("data").mkdir(exist_ok=True)

    ydl_opts = {
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "mp3",
                "preferredquality": "192",
            }
        ],
        "postprocessor_args": ["-ar", "16000"],
        "outtmpl": "data/%(id)s.%(ext)s",
        "keepvideo": True,
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download(["https://www.youtube.com/watch?v=BaW_jenozKc"])

    path = Path("data/BaW_jenozKc")

    if not path.exists():
        path.mkdir(exist_ok=True)

        (
            ffmpeg.input("data/BaW_jenozKc.mp4")
            .filter("fps", fps="25")
            .output("data/BaW_jenozKc/%06d.png", start_number=0)
            .overwrite_output()
            .run(quiet=True)
        )


class MyMultimodalSample(MultimodalSample):
    def generate_info(self):
        wav_modality = self.get_modality_by_name("wav")
        rgbs_modality = self.get_modality_by_name("rgbs")
        return dict(
            t0=0,
            t1=wav_modality.duration / 10,
            original_wav_seconds=wav_modality.duration,
            original_rgbs_seconds=rgbs_modality.duration,
        )


class MyMultimodalDataset(MultimodalDataset):
    Sample = MyMultimodalSample


def main():
    download()

    # optional multiprocessing cache manager
    manager = Manager()

    dataset = MyMultimodalDataset(
        ["BaW_jenozKc"],
        modality_factories=[
            RgbsModality.create_factory(
                name="rgbs",
                root="data",
                suffix="*.png",
                sample_rate=25,
                transform=transforms.Compose(
                    [
                        transforms.Resize((28, 28)),
                        transforms.ToTensor(),
                        transforms.Normalize(0.5, 1),
                    ],
                ),
                aggragate=torch.stack,
                cache=manager.dict(),
            ),
            WavModality.create_factory(
                name="wav",
                root="data",
                suffix=".mp3",
                sample_rate=16_000,
                cache=manager.dict(),
            ),
            MelModality.create_factory(
                name="mel",
                root="data",
                suffix=".mel.npz",
                mel_fn=LogMelSpectrogram(sample_rate=16_000),
                base_modality_name="wav",
                cache=manager.dict(),
            ),
            F0Modality.create_factory(
                name="f0",
                root="data",
                suffix=".f0.npz",
                mel_fn=LogMelSpectrogram(sample_rate=16_000),
                base_modality_name="wav",
                cache=manager.dict(),
            ),
            Ge2eModality.create_factory(
                name="ge2e",
                root="data",
                suffix=".ge2e.npz",
                sample_rate=16_000,
                base_modality_name="wav",
                cache=manager.dict(),
                fetching=False,
            ),
        ],
    )

    # first load
    print(timeit.timeit(lambda: dataset[0], number=1))

    # second load
    print(timeit.timeit(lambda: dataset[0], number=1))

    print(dataset[0]["info"])

    for key, value in dataset[0].items():
        try:
            print(key, value.shape, type(value))
        except:
            pass


if __name__ == "__main__":
    main()

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mmds-0.0.1.dev20211221204032.tar.gz (12.5 kB view details)

Uploaded Source

Built Distribution

mmds-0.0.1.dev20211221204032-py3-none-any.whl (15.9 kB view details)

Uploaded Python 3

File details

Details for the file mmds-0.0.1.dev20211221204032.tar.gz.

File metadata

  • Download URL: mmds-0.0.1.dev20211221204032.tar.gz
  • Upload date:
  • Size: 12.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1

File hashes

Hashes for mmds-0.0.1.dev20211221204032.tar.gz
Algorithm Hash digest
SHA256 d5380c3a50557bb5bc9ee2a664f4521174aa67319a06a4509f208aeecbe99932
MD5 311ca27ce731f928af6e80827f8b3a21
BLAKE2b-256 d63734f9750e88cff948f47ae2ecaa76ba7d086b1280f1014a52e05e4d723171

See more details on using hashes here.

File details

Details for the file mmds-0.0.1.dev20211221204032-py3-none-any.whl.

File metadata

  • Download URL: mmds-0.0.1.dev20211221204032-py3-none-any.whl
  • Upload date:
  • Size: 15.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1

File hashes

Hashes for mmds-0.0.1.dev20211221204032-py3-none-any.whl
Algorithm Hash digest
SHA256 ab05bd3c3329ac3d8464119ad3361a7438d840c13de6db07ab045db1e7859732
MD5 14040f84c0998ddbf2e1bd03dd0e0333
BLAKE2b-256 d1326b2fe55d831afd240d0a0c076eb9700c8791fd123c18556d5ac01507187e

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page