
Tools to generate and use multi-object datasets


Multi-object datasets

Tools to generate and use multi-object datasets. The datasets consist of images and a dictionary of labels, in which each image is labeled with (1) the number of objects it contains and (2) the attributes of each object.

Using the datasets requires only numpy, since they are stored as .npz files. Generating sprites requires scikit-image. Tools for using the datasets in PyTorch are provided, together with usage examples.
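Because each dataset is a plain .npz archive, its contents can be inspected with numpy alone. The sketch below writes a tiny toy archive and lists what it contains. The key names `x` and `labels`, and storing the label dictionary as a pickled object array, are assumptions for illustration only; the keys of a real file can be checked with `data.files`.

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Return {key: (shape, dtype)} for every array stored in an .npz file."""
    # allow_pickle=True is needed if a dict of labels was stored as an object array
    with np.load(path, allow_pickle=True) as data:
        return {key: (data[key].shape, data[key].dtype) for key in data.files}


# Toy archive with hypothetical keys: 4 binary 64x64 RGB images
# and a label dict holding the number of objects per image.
images = np.zeros((4, 64, 64, 3), dtype=np.uint8)
labels = {"n_obj": np.array([0, 1, 2, 1])}
path = os.path.join(tempfile.mkdtemp(), "toy_dataset.npz")
np.savez_compressed(path, x=images, labels=labels)

print(describe_npz(path))

# A dict saved this way comes back as a 0-d object array; .item() recovers it
with np.load(path, allow_pickle=True) as data:
    restored = data["labels"].item()
print(restored["n_obj"])
```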

Basic usage (pip package)

  1. Either download one of the datasets in generated/, or generate a new one.
  2. Place the .npz dataset in /path/to/data/.
  3. pip install multiobject
  4. Usage in PyTorch:
    from multiobject.pytorch import MultiObjectDataLoader, MultiObjectDataset

    dataset_path = '/path/to/data/some_dataset.npz'
    batch_size = 64        # example value; choose as needed
    test_batch_size = 100  # example value; choose as needed

    # Both splits are loaded from the same .npz file
    train_set = MultiObjectDataset(dataset_path, train=True)
    test_set = MultiObjectDataset(dataset_path, train=False)
    train_loader = MultiObjectDataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = MultiObjectDataLoader(test_set, batch_size=test_batch_size)
    

Run demos

conda create --name multiobject python=3.7
conda activate multiobject
pip install -r requirements.txt
CUDA_VISIBLE_DEVICES=0 python demo_vae.py
CUDA_VISIBLE_DEVICES=0 python demo_count.py

Available datasets

Datasets are available as .npz files in ./generated/.

dSprites¹

Binary 64x64 RGB images with monochromatic dSprites on a black canvas. Sprites are 18x18 and come in 7 different colors. Sprites can overlap, in which case pixel values are summed and clipped.

  • 100k images with 1 sprite per image [10.6 MB]
  • 100k images with 1 sprite per image, larger sprites (max 28x28) [12.4 MB]
  • 100k images with 0, 1, or 2 sprites per image (uniformly distributed) [11 MB]
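The "sum and clip" overlap rule and the uniform choice of 0, 1, or 2 sprites per image can be illustrated with a short numpy sketch. It is not the repository's generator: the sprite shape (a filled red square), the position sampling, and the helper name `compose_image` are hypothetical.

```python
import numpy as np


def compose_image(sprites, canvas_size, rng):
    """Paste sprites at random positions; overlapping pixels are summed, then clipped to [0, 1]."""
    channels = sprites[0].shape[2] if sprites else 3
    canvas = np.zeros((canvas_size, canvas_size, channels), dtype=np.float32)
    for sprite in sprites:
        h, w = sprite.shape[:2]
        # Top-left corner chosen so the sprite lies fully inside the canvas
        top = rng.integers(0, canvas_size - h + 1)
        left = rng.integers(0, canvas_size - w + 1)
        canvas[top:top + h, left:left + w] += sprite  # sum ...
    return np.clip(canvas, 0.0, 1.0)  # ... and clip


rng = np.random.default_rng(0)

# Hypothetical monochromatic sprite: an 18x18 filled square in pure red
red_square = np.zeros((18, 18, 3), dtype=np.float32)
red_square[..., 0] = 1.0

# Number of sprites drawn uniformly from {0, 1, 2}
n_obj = int(rng.integers(0, 3))
image = compose_image([red_square] * n_obj, canvas_size=64, rng=rng)
print(n_obj, image.shape, image.max())
```

With binary sprites, summing and clipping keeps every pixel in {0, 1}, so the composed image stays binary.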


Binarized MNIST

Binary 64x64 single-channel images with MNIST digits on a black canvas. Digits are rescaled to 18x18 and binarized, and they can overlap (pixel values are summed and clipped). Only the 60k digits from the MNIST training set are used.

  • 100k images with 1 digit per image [4.5 MB]
  • 100k images with 0, 1, or 2 digits per image (uniformly distributed) [4.8 MB]
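Rescaling a digit to 18x18 and binarizing it can be sketched with numpy alone. This sketch uses nearest-neighbour sampling and a fixed threshold of 0.5, both of which are assumptions; the repository's generator may use a different interpolation (scikit-image is among its requirements) and a different threshold. A random 28x28 array stands in for an actual MNIST digit.

```python
import numpy as np


def resize_nearest(img, out_h, out_w):
    """Nearest-neighbour resize of a 2-D array."""
    in_h, in_w = img.shape
    # Map each output pixel centre to the input pixel that contains it
    rows = ((np.arange(out_h) + 0.5) * in_h / out_h).astype(int)
    cols = ((np.arange(out_w) + 0.5) * in_w / out_w).astype(int)
    return img[rows[:, None], cols[None, :]]


def binarize(img, threshold=0.5):
    """Map intensities in [0, 1] to {0, 1}."""
    return (img >= threshold).astype(np.uint8)


# Stand-in for a 28x28 MNIST digit with intensities in [0, 1]
rng = np.random.default_rng(0)
digit = rng.random((28, 28))

small = binarize(resize_nearest(digit, 18, 18))
print(small.shape, np.unique(small))
```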


Generating a new dataset

  1. Clone this repo.

  2. Install the requirements (see Requirements below), or set up a virtual environment as follows:

    conda create --name multiobject python=3.7
    conda activate multiobject
    pip install -r requirements.txt
    
  3. Optional: add a new type of sprites:

    1. Create a file sprites/xyz.py containing a function generate_xyz(), where "xyz" denotes the new sprite type.
    2. In generate_dataset.py, add a call to generate_xyz() so that the new sprites are generated, and add 'xyz' to the list of supported sprite types.
  4. Run generate_dataset.py with the desired sprite type as the --type argument. For example:

    python generate_dataset.py --type dsprites
    

When a dataset is generated from sprites that have per-sprite labels, the object attributes are handled automatically. However, since the attributes are dataset-specific, they must be defined when the sprites are created.
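To illustrate the optional step of adding a sprite type, and where per-sprite labels come from, here is a hypothetical `generate_xyz()` for a toy sprite type: filled squares of varying size. The return format (an array of sprites plus a dict of per-sprite label arrays) is an assumption; the format generate_dataset.py actually expects should be checked against the existing modules in sprites/.

```python
import numpy as np


def generate_xyz(sprite_size=18):
    """Hypothetical sprite generator: centred filled squares of several sizes.

    Returns:
        sprites: float32 array of shape (n_sprites, sprite_size, sprite_size)
        labels: dict mapping attribute name -> array with one entry per sprite
    """
    sides = np.arange(4, sprite_size + 1, 2)  # square side lengths to generate
    sprites = np.zeros((len(sides), sprite_size, sprite_size), dtype=np.float32)
    for i, side in enumerate(sides):
        start = (sprite_size - side) // 2
        sprites[i, start:start + side, start:start + side] = 1.0
    # Per-sprite attributes are defined here, at sprite creation time
    labels = {"side": sides}
    return sprites, labels


sprites, labels = generate_xyz()
print(sprites.shape, labels["side"])
```

Keeping the labels aligned with the first axis of the sprite array means any sprite picked by index during dataset generation carries its attributes along.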

Note: for now, the following parameters must be customized directly in generate_dataset.py:

  • probability distribution over number of objects
  • image size
  • sprite size
  • dataset size
  • whether sprites can overlap
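The parameters listed above could be grouped in one place, as in the sketch below. Everything in it is hypothetical: the names, the default values, and the rejection-sampling strategy for placing non-overlapping sprites are illustrative and are not taken from generate_dataset.py.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class GenerationParams:
    """Hypothetical container for the settings currently edited in generate_dataset.py."""
    n_obj_probs: tuple = (1 / 3, 1 / 3, 1 / 3)  # P(0 objects), P(1 object), P(2 objects)
    image_size: int = 64
    sprite_size: int = 18
    dataset_size: int = 100_000  # listed for completeness; unused below
    allow_overlap: bool = True


def sample_positions(params, rng, max_tries=10_000):
    """Draw the number of objects, then a top-left corner for each sprite.

    When overlap is disallowed, a candidate is rejected if its square
    intersects a sprite that has already been placed.
    """
    n_obj = int(rng.choice(len(params.n_obj_probs), p=params.n_obj_probs))
    max_corner = params.image_size - params.sprite_size
    s = params.sprite_size
    positions = []
    for _ in range(max_tries):
        if len(positions) == n_obj:
            return positions
        y, x = (int(v) for v in rng.integers(0, max_corner + 1, size=2))
        # Two s-by-s squares intersect iff both corner offsets are smaller than s
        clash = any(abs(y - py) < s and abs(x - px) < s for py, px in positions)
        if params.allow_overlap or not clash:
            positions.append((y, x))
    if len(positions) == n_obj:
        return positions
    raise RuntimeError("could not place all sprites without overlap")


params = GenerationParams(allow_overlap=False)
rng = np.random.default_rng(0)
print(sample_positions(params, rng))
```

Rejection sampling is simple and works well when sprites are small relative to the canvas; with many large sprites it may need many tries, hence the explicit limit.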

Requirements

To generate datasets:

numpy==1.18.1
matplotlib==3.1.2
scikit_image==0.16.2
tqdm==4.41.1
pillow==7.0.0

To run the examples or use the PyTorch tools:

torch==1.4.0
torchvision==0.5.0

Footnotes

1 This is an extension of the original dSprites dataset to multiple objects per image and to color images.


Download files

Source Distribution

multiobject-0.0.3.tar.gz (6.9 kB)

Built Distribution

multiobject-0.0.3-py3-none-any.whl (8.1 kB)

File details

Details for the file multiobject-0.0.3.tar.gz.

File metadata

  • Download URL: multiobject-0.0.3.tar.gz
  • Size: 6.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.6

File hashes

Hashes for multiobject-0.0.3.tar.gz
Algorithm    Hash digest
SHA256       4459d8a61e9bc3cae2c2166c0da1d09dc2986dd3605e4eceed690dffdcd56991
MD5          7bfd947cd01081b73fd5fbf2755e9c5c
BLAKE2b-256  6d137608269c70b3ef47d4f5deaab0a0f783319e0db32285971a8828f81b2949
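A downloaded distribution file can be checked against the SHA256 digest listed above using Python's standard hashlib module. This is a minimal sketch; the self-check at the end runs on a throwaway temporary file rather than on the real archive.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path, chunk_size=1 << 16):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_hex):
    """Return True if the file's SHA256 digest equals the expected hex string."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Self-check on a throwaway file; a real check would point at the downloaded archive
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.bin")
    with open(demo_path, "wb") as f:
        f.write(b"multiobject" * 10_000)  # larger than one chunk
    demo_ok = verify(demo_path, hashlib.sha256(b"multiobject" * 10_000).hexdigest())
    demo_bad = verify(demo_path, "0" * 64)
print(demo_ok, demo_bad)
```

For example, `verify("multiobject-0.0.3.tar.gz", "4459d8a61e9bc3cae2c2166c0da1d09dc2986dd3605e4eceed690dffdcd56991")` should return True for an unmodified copy of the source archive in the current directory.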


File details

Details for the file multiobject-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: multiobject-0.0.3-py3-none-any.whl
  • Size: 8.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.6

File hashes

Hashes for multiobject-0.0.3-py3-none-any.whl
Algorithm    Hash digest
SHA256       0a062b186886fcca12da8fdeacb0535b183efebcac74ddcd082b968f178f93cd
MD5          f0d643b399592a4e4a91c55ee8f3c548
BLAKE2b-256  652f7a029462782459b99895104ffe8a56fa8ceb6a63929ab7b4a30dd8394344

