Pytorch MONet implementation
Pytorch implementation of Multi-Object Network (MONet)
How to install
You can install it through pip with the following command:
pip install monet-pytorch
or clone this repository locally and install it with Poetry:
git clone https://github.com/Michedev/MONet-pytorch
cd MONet-pytorch
poetry install
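To verify the installation, you can try importing the model class (a minimal sanity check, assuming the package was installed into the active environment):
from monet_pytorch import Monet  # raises ImportError if the installation did not work
print(Monet)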
How to use
The package comes with a set of predefined configurations based on the paper specifications, namely monet and monet-iodine (MONet as defined in the IODINE paper).
from monet_pytorch import Monet
monet = Monet.from_config(model='monet')
There is also a custom architecture, monet-lightweight, which has fewer parameters than the original ones.
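As a quick way to compare the predefined architectures, you can count their trainable parameters. This is a minimal sketch that assumes Monet is a standard torch.nn.Module and that monet-lightweight is selected the same way as the other predefined configurations:
from monet_pytorch import Monet

# Instantiate two predefined architectures and compare their number of trainable parameters.
monet = Monet.from_config(model='monet')
monet_light = Monet.from_config(model='monet-lightweight')

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print('monet:            ', count_params(monet))
print('monet-lightweight:', count_params(monet_light))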
Furthermore, the model architecture changes slightly (e.g. the U-Net blocks) depending on the dataset, which can be picked from the ones cited in the MONet paper (CLEVR 6, Multi-dSprites colored on colored, Multi-dSprites colored on grayscale, Tetrominoes).
from monet_pytorch import Monet
monet = Monet.from_config(model='monet', dataset='tetrominoes')
Alternatively, you can set custom dataset parameters through the function arguments:
from monet_pytorch import Monet
monet = Monet.from_config(model='monet', dataset_width=48, dataset_height=48, scene_max_objects=6)
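Once instantiated, the model can be used like any other torch.nn.Module. The sketch below assumes a forward pass over a batch of RGB images matching the configured resolution; the exact output structure (e.g. reconstructions, masks, losses) is implementation-specific, so inspect it rather than relying on particular keys:
import torch
from monet_pytorch import Monet

monet = Monet.from_config(model='monet', dataset_width=48, dataset_height=48, scene_max_objects=6)

# Batch of 4 RGB images at the configured 48x48 resolution (random data, illustration only).
images = torch.rand(4, 3, 48, 48)

with torch.no_grad():
    output = monet(images)

# The output structure depends on the implementation; print it to see what the model returns.
print(type(output))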
Lastly, you can build your own custom MONet by passing your configuration as an OmegaConf DictConfig:
from omegaconf import OmegaConf

from monet_pytorch import Monet

custom_monet_config = OmegaConf.create("""
dataset:
  width: 44
  height: 44
  max_num_objects: 10
model:  # this config follows the MONet implementation from the IODINE paper
  _target_: monet_pytorch.model.Monet
  height: ${dataset.height}
  width: ${dataset.width}
  num_slots: ${dataset.max_num_objects}
  name: monet-iodine
  bg_sigma: 0.32
  fg_sigma: 0.1
  beta_kl: 0.43
  gamma: 0.5
  latent_size: 16
  input_channels: 3
  encoder:
    _target_: torch.nn.Sequential
    _args_:
      - _target_: monet_pytorch.template.sequential_cnn.make_sequential_cnn_from_config
        channels: [44, 44, 32, 14]
        kernels: 3
        strides: 2
        paddings: 0
        input_channels: 4
        batchnorms: true
        bn_affines: false
        activations: relu
      - _target_: torch.nn.Flatten
        start_dim: 1
      - _target_: torch.nn.Linear
        in_features: 256
        out_features: 256
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Linear
        in_features: 256
        out_features: ${prod:${model.latent_size},2}
  decoder:
    _target_: monet_pytorch.template.encoder_decoder.BroadcastDecoderNet
    w_broadcast: ${sum:${dataset.width},8}
    h_broadcast: ${sum:${dataset.height},8}
    net:
      _target_: monet_pytorch.template.sequential_cnn.make_sequential_cnn_from_config
      input_channels: ${sum:${model.latent_size},2}  # latent size + 2
      channels: [32, 32, 64, 64, 4]  # last is 4 channels: RGB (3) + mask (1)
      kernels: [3, 3, 3, 3, 1]
      paddings: 0
      activations: [relu, relu, relu, relu, null]  # null means no activation function
      batchnorms: [true, true, true, true, false]
      bn_affines: [false, false, false, false, false]
  unet:
    _target_: monet_pytorch.unet.UNet
    input_channels: ${model.input_channels}
    num_blocks: 5
    filter_start: 16
    mlp_size: 128""")
custom_monet: Monet = Monet.from_custom_config(custom_monet_config)
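Note that the configuration above relies on the custom OmegaConf resolvers ${sum:...} and ${prod:...}; the package presumably registers them when it is imported. The sketch below only illustrates the underlying idea with plain OmegaConf, with the resolver names and semantics assumed from the config above:
import math
from omegaconf import OmegaConf

# Illustration only: resolvers that add or multiply their arguments, as used in the config above.
OmegaConf.register_new_resolver("sum", lambda *xs: sum(xs))
OmegaConf.register_new_resolver("prod", lambda *xs: math.prod(xs))

cfg = OmegaConf.create({"latent_size": 16, "decoder_input_channels": "${sum:${latent_size},2}"})
print(OmegaConf.to_container(cfg, resolve=True))  # {'latent_size': 16, 'decoder_input_channels': 18}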
Model performance
This implementation reproduces MONet's reported ARI values very closely.
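For reference, ARI is typically computed by treating each pixel's object id as a cluster label and comparing the predicted segmentation with the ground truth. The snippet below is a generic illustration using scikit-learn, not part of this package:
import numpy as np
from sklearn.metrics import adjusted_rand_score

# Toy example: per-pixel object ids for ground truth and prediction (random data, illustration only).
true_segmentation = np.random.randint(0, 5, size=(64, 64)).ravel()
pred_segmentation = np.random.randint(0, 5, size=(64, 64)).ravel()
print(adjusted_rand_score(true_segmentation, pred_segmentation))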
Special thanks
I would like to thank @apra and @addtt for their help in fixing code bugs and improving the model's performance.