MMDS: A general-purpose multimodal dataset wrapper.

This project is under construction; the API may change from time to time.

Installation

Stable (not stable yet though)

pip install mmds

Latest

pip install mmds --pre
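
Because `--pre` lets pip pick development builds, it can be useful to confirm which release was actually installed. The following is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper for illustration and is not part of mmds.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is absent.

    Hypothetical helper for illustration; not part of mmds.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints the version string if mmds is installed, otherwise None.
print(installed_version("mmds"))
```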

Example Usage

# example.py

import timeit
from pathlib import Path
from multiprocessing import Manager

from mmds import MultimodalDataset, MultimodalSample
from mmds.exceptions import PackageNotFoundError
from mmds.modalities.rgbs import RgbsModality
from mmds.modalities.wav import WavModality
from mmds.modalities.mel import MelModality
from mmds.modalities.f0 import F0Modality
from mmds.modalities.ge2e import Ge2eModality
from mmds.utils.spectrogram import LogMelSpectrogram


try:
    import youtube_dl
    import ffmpeg
    import torch
    from torchvision import transforms
except ImportError:
    raise PackageNotFoundError(
        "youtube_dl",
        "ffmpeg-python",
        "torch",
        "torchvision",
        by="example.py",
    )


def download():
    Path("data").mkdir(exist_ok=True)

    ydl_opts = {
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "mp3",
                "preferredquality": "192",
            }
        ],
        "postprocessor_args": ["-ar", "16000"],
        "outtmpl": "data/%(id)s.%(ext)s",
        "keepvideo": True,
    }
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        ydl.download(["https://www.youtube.com/watch?v=BaW_jenozKc"])

    path = Path("data/BaW_jenozKc")

    if not path.exists():
        path.mkdir(exist_ok=True)

        (
            ffmpeg.input("data/BaW_jenozKc.mp4")
            .filter("fps", fps="25")
            .output("data/BaW_jenozKc/%06d.png", start_number=0)
            .overwrite_output()
            .run(quiet=True)
        )


class MyMultimodalSample(MultimodalSample):
    def generate_info(self):
        wav_modality = self.get_modality_by_name("wav")
        rgbs_modality = self.get_modality_by_name("rgbs")
        return dict(
            t0=0,
            t1=wav_modality.duration / 10,
            original_wav_seconds=wav_modality.duration,
            original_rgbs_seconds=rgbs_modality.duration,
        )


class MyMultimodalDataset(MultimodalDataset):
    Sample = MyMultimodalSample


def main():
    download()

    # optional multiprocessing cache manager
    manager = Manager()

    dataset = MyMultimodalDataset(
        ["BaW_jenozKc"],
        modality_factories=[
            RgbsModality.create_factory(
                name="rgbs",
                root="data",
                suffix="*.png",
                sample_rate=25,
                transform=transforms.Compose(
                    [
                        transforms.Resize((28, 28)),
                        transforms.ToTensor(),
                        transforms.Normalize(0.5, 1),
                    ],
                ),
                aggragate=torch.stack,
                cache=manager.dict(),
            ),
            WavModality.create_factory(
                name="wav",
                root="data",
                suffix=".mp3",
                sample_rate=16_000,
                cache=manager.dict(),
            ),
            MelModality.create_factory(
                name="mel",
                root="data",
                suffix=".mel.npz",
                mel_fn=LogMelSpectrogram(sample_rate=16_000),
                base_modality_name="wav",
                cache=manager.dict(),
            ),
            F0Modality.create_factory(
                name="f0",
                root="data",
                suffix=".f0.npz",
                mel_fn=LogMelSpectrogram(sample_rate=16_000),
                base_modality_name="wav",
                cache=manager.dict(),
            ),
            Ge2eModality.create_factory(
                name="ge2e",
                root="data",
                suffix=".ge2e.npz",
                sample_rate=16_000,
                base_modality_name="wav",
                cache=manager.dict(),
                fetching=False,
            ),
        ],
    )

    # first load
    print(timeit.timeit(lambda: dataset[0], number=1))

    # second load
    print(timeit.timeit(lambda: dataset[0], number=1))

    print(dataset[0]["info"])

    for key, value in dataset[0].items():
        try:
            print(key, value.shape, type(value))
        except AttributeError:
            # entries without a .shape attribute are skipped
            pass


if __name__ == "__main__":
    main()
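
In the example above, `generate_info` returns a window (`t0`, `t1`) in seconds, while each modality factory declares its own `sample_rate`. One plausible reading, which is an assumption and not confirmed by this page, is that every modality converts the shared window into its own index range. The sketch below illustrates that idea with a hypothetical helper, `window_to_indices`; it is not the mmds implementation.

```python
import math


def window_to_indices(t0: float, t1: float, sample_rate: float) -> range:
    """Hypothetical helper: map a window in seconds to sample indices.

    Not part of mmds; it only illustrates how one shared time window
    could select a different number of items per modality.
    """
    if t1 < t0:
        raise ValueError("t1 must not be smaller than t0")
    start = math.floor(t0 * sample_rate)
    stop = math.floor(t1 * sample_rate)
    return range(start, stop)


# A 2-second window at the two sample rates used in the example above.
rgb_frames = window_to_indices(0.0, 2.0, sample_rate=25)       # video frames
wav_samples = window_to_indices(0.0, 2.0, sample_rate=16_000)  # audio samples

print(len(rgb_frames), len(wav_samples))  # 50 32000
```

Under this assumption, the same window yields 50 image frames and 32,000 audio samples, which is why aligning modalities by time rather than by index keeps them in sync.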


